-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathactivations.py
605 lines (468 loc) · 18.5 KB
/
activations.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Built-in activation functions."""
from tensorflow.python.keras import backend
from tensorflow.python.keras.layers import advanced_activations
from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object
from tensorflow.python.keras.utils.generic_utils import serialize_keras_object
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import keras_export
# b/123041942
# Under TF 2.x, using `tf.nn.softmax` as a Keras layer activation causes it to
# be serialized under its internal method name, 'softmax_v2', rather than
# 'softmax'. Keras has no activation registered as `softmax_v2`, so exporting
# and re-loading such a model fails.
# The mapping below translates each v2 function name back to the canonical
# activation name.
_TF_ACTIVATIONS_V2 = {'softmax_v2': 'softmax'}
@keras_export('keras.activations.softmax')
@dispatch.add_dispatch_support
def softmax(x, axis=-1):
  """Converts a tensor of scores into probability distributions.

  Each vector taken along `axis` is normalized independently so that its
  elements lie in the range (0, 1) and sum to 1. For a vector `x` the result
  is `exp(x) / tf.reduce_sum(exp(x))`, i.e. the input values are the log-odds
  of the resulting probabilities.

  Softmax is commonly used as the activation of the last layer of a
  classification network, because its output can be read as a probability
  distribution.

  Args:
    x: Input tensor. Its rank must be greater than 1.
    axis: Integer, axis along which the softmax normalization is applied.
      A non-integer value (e.g. a tuple of axes) is handled by an explicit
      exp / sum computation, because `nn.softmax` does not support tuple
      axes.

  Returns:
    Tensor, output of the softmax transformation (all values are
    non-negative and sum to 1). The input `x` is attached to the returned
    tensor as `_keras_logits`.

  Raises:
    ValueError: If the rank of `x` is not greater than 1.

  Examples:

  **Example 1: standalone usage**

  >>> inputs = tf.random.normal(shape=(32, 10))
  >>> outputs = tf.keras.activations.softmax(inputs)
  >>> tf.reduce_sum(outputs[0, :])  # Each sample in the batch now sums to 1
  <tf.Tensor: shape=(), dtype=float32, numpy=1.0000001>

  **Example 2: usage in a `Dense` layer**

  >>> layer = tf.keras.layers.Dense(32, activation=tf.keras.activations.softmax)
  """
  # Reject low-rank inputs up front.
  if x.shape.rank <= 1:
    raise ValueError('Cannot apply softmax to a tensor that is 1D. '
                     'Received input: %s' % (x,))

  if isinstance(axis, int):
    output = nn.softmax(x, axis=axis)
  else:
    # nn.softmax does not support tuple axis, so normalize by hand. The
    # per-slice maximum is subtracted before exponentiating.
    shifted = x - math_ops.reduce_max(x, axis=axis, keepdims=True)
    exponentials = math_ops.exp(shifted)
    normalizer = math_ops.reduce_sum(exponentials, axis=axis, keepdims=True)
    output = exponentials / normalizer

  # Cache the logits to use for crossentropy loss.
  output._keras_logits = x  # pylint: disable=protected-access
  return output
@keras_export('keras.activations.elu')
@dispatch.add_dispatch_support
def elu(x, alpha=1.0):
  """Applies the Exponential Linear Unit (ELU) activation.

  For `alpha > 0` the ELU is defined as:

  - `x` if `x > 0`
  - `alpha * (exp(x) - 1)` if `x < 0`

  The hyperparameter `alpha` sets the value an ELU saturates to for negative
  net inputs. ELUs diminish the vanishing gradient effect. Because they can
  take negative values, they push the mean of the activations closer to
  zero, which enables faster learning by bringing the gradient closer to the
  natural gradient. For increasingly negative arguments the ELU saturates to
  a negative value; saturation means a small derivative, which decreases the
  variation and the information propagated to the next layer.

  Example Usage:

  >>> import tensorflow as tf
  >>> model = tf.keras.Sequential()
  >>> model.add(tf.keras.layers.Conv2D(32, (3, 3), activation='elu',
  ...          input_shape=(28, 28, 1)))
  >>> model.add(tf.keras.layers.MaxPooling2D((2, 2)))
  >>> model.add(tf.keras.layers.Conv2D(64, (3, 3), activation='elu'))
  >>> model.add(tf.keras.layers.MaxPooling2D((2, 2)))
  >>> model.add(tf.keras.layers.Conv2D(64, (3, 3), activation='elu'))

  Args:
    x: Input tensor.
    alpha: A scalar, slope of the negative section. `alpha` controls the
      value to which an ELU saturates for negative net inputs.

  Returns:
    The exponential linear unit (ELU) activation: `x` if `x > 0` and
    `alpha * (exp(x) - 1)` if `x < 0`.

  Reference:
    [Fast and Accurate Deep Network Learning by Exponential Linear Units
    (ELUs) (Clevert et al, 2016)](https://arxiv.org/abs/1511.07289)
  """
  # The computation is delegated to the Keras backend.
  activated = backend.elu(x, alpha)
  return activated
@keras_export('keras.activations.selu')
@dispatch.add_dispatch_support
def selu(x):
  """Applies the Scaled Exponential Linear Unit (SELU) activation.

  SELU is defined as:

  - `if x > 0: return scale * x`
  - `if x < 0: return scale * alpha * (exp(x) - 1)`

  with the pre-defined constants `alpha=1.67326324` and `scale=1.05070098`.

  In other words, SELU multiplies the output of
  `tf.keras.activations.elu` by `scale` (> 1) to ensure a slope larger than
  one for positive inputs.

  `alpha` and `scale` are chosen so that the mean and variance of the inputs
  are preserved between two consecutive layers, provided the weights are
  initialized correctly (see the `tf.keras.initializers.LecunNormal`
  initializer) and the number of input units is "large enough" (see the
  reference paper for more information).

  Example Usage:

  >>> num_classes = 10  # 10-class problem
  >>> model = tf.keras.Sequential()
  >>> model.add(tf.keras.layers.Dense(64, kernel_initializer='lecun_normal',
  ...                                 activation='selu'))
  >>> model.add(tf.keras.layers.Dense(32, kernel_initializer='lecun_normal',
  ...                                 activation='selu'))
  >>> model.add(tf.keras.layers.Dense(16, kernel_initializer='lecun_normal',
  ...                                 activation='selu'))
  >>> model.add(tf.keras.layers.Dense(num_classes, activation='softmax'))

  Args:
    x: A tensor or variable to compute the activation function for.

  Returns:
    The scaled exponential unit activation: `scale * elu(x, alpha)`.

  Notes:
    - To be used together with the
      `tf.keras.initializers.LecunNormal` initializer.
    - To be used together with the dropout variant
      `tf.keras.layers.AlphaDropout` (not regular dropout).

  References:
    - [Klambauer et al., 2017](https://arxiv.org/abs/1706.02515)
  """
  # The computation is delegated to the `nn.selu` op.
  activated = nn.selu(x)
  return activated
@keras_export('keras.activations.softplus')
@dispatch.add_dispatch_support
def softplus(x):
  """Applies the softplus activation, `softplus(x) = log(exp(x) + 1)`.

  Example Usage:

  >>> a = tf.constant([-20, -1.0, 0.0, 1.0, 20], dtype = tf.float32)
  >>> b = tf.keras.activations.softplus(a)
  >>> b.numpy()
  array([2.0611537e-09, 3.1326166e-01, 6.9314718e-01, 1.3132616e+00,
           2.0000000e+01], dtype=float32)

  Args:
    x: Input tensor.

  Returns:
    The softplus activation: `log(exp(x) + 1)`.
  """
  # The computation is delegated to `math_ops.softplus`.
  activated = math_ops.softplus(x)
  return activated
@keras_export('keras.activations.softsign')
@dispatch.add_dispatch_support
def softsign(x):
  """Applies the softsign activation, `softsign(x) = x / (abs(x) + 1)`.

  Example Usage:

  >>> a = tf.constant([-1.0, 0.0, 1.0], dtype = tf.float32)
  >>> b = tf.keras.activations.softsign(a)
  >>> b.numpy()
  array([-0.5,  0. ,  0.5], dtype=float32)

  Args:
    x: Input tensor.

  Returns:
    The softsign activation: `x / (abs(x) + 1)`.
  """
  # The computation is delegated to the `nn.softsign` op.
  activated = nn.softsign(x)
  return activated
@keras_export('keras.activations.swish')
@dispatch.add_dispatch_support
def swish(x):
  """Applies the swish activation, `swish(x) = x * sigmoid(x)`.

  Swish is a smooth, non-monotonic function that consistently matches or
  outperforms ReLU on deep networks. It is unbounded above and bounded
  below.

  Example Usage:

  >>> a = tf.constant([-20, -1.0, 0.0, 1.0, 20], dtype = tf.float32)
  >>> b = tf.keras.activations.swish(a)
  >>> b.numpy()
  array([-4.1223075e-08, -2.6894143e-01,  0.0000000e+00,  7.3105860e-01,
            2.0000000e+01], dtype=float32)

  Args:
    x: Input tensor.

  Returns:
    The swish activation applied to `x` (see reference paper for details).

  Reference:
    - [Ramachandran et al., 2017](https://arxiv.org/abs/1710.05941)
  """
  # The computation is delegated to the `nn.swish` op.
  activated = nn.swish(x)
  return activated
@keras_export('keras.activations.relu')
@dispatch.add_dispatch_support
def relu(x, alpha=0., max_value=None, threshold=0):
  """Applies the rectified linear unit (ReLU) activation function.

  With the default arguments this is the standard ReLU, `max(x, 0)`: the
  element-wise maximum of 0 and the input tensor.

  Overriding the defaults lets you use a non-zero threshold, cap the
  activation at a maximum value, and apply a non-zero multiple of the input
  to values below the threshold.

  For example:

  >>> foo = tf.constant([-10, -5, 0.0, 5, 10], dtype = tf.float32)
  >>> tf.keras.activations.relu(foo).numpy()
  array([ 0.,  0.,  0.,  5., 10.], dtype=float32)
  >>> tf.keras.activations.relu(foo, alpha=0.5).numpy()
  array([-5. , -2.5,  0. ,  5. , 10. ], dtype=float32)
  >>> tf.keras.activations.relu(foo, max_value=5).numpy()
  array([0., 0., 0., 5., 5.], dtype=float32)
  >>> tf.keras.activations.relu(foo, threshold=5).numpy()
  array([-0., -0.,  0.,  0., 10.], dtype=float32)

  Args:
    x: Input `tensor` or `variable`.
    alpha: A `float` that governs the slope for values lower than the
      threshold.
    max_value: A `float` that sets the saturation threshold (the largest
      value the function will return).
    threshold: A `float` giving the threshold value of the activation
      function below which values will be damped or set to zero.

  Returns:
    A `Tensor` representing the input tensor, transformed by the relu
    activation function. The tensor has the same shape and dtype as the
    input `x`.
  """
  # All options are forwarded by keyword to the Keras backend.
  options = dict(alpha=alpha, max_value=max_value, threshold=threshold)
  return backend.relu(x, **options)
@keras_export('keras.activations.gelu', v1=[])
@dispatch.add_dispatch_support
def gelu(x, approximate=False):
  """Applies the Gaussian error linear unit (GELU) activation function.

  GELU computes `x * P(X <= x)`, where `P(X) ~ N(0, 1)`. The nonlinearity
  weights inputs by their value, rather than gating inputs by their sign as
  ReLU does.

  For example:

  >>> x = tf.constant([-3.0, -1.0, 0.0, 1.0, 3.0], dtype=tf.float32)
  >>> y = tf.keras.activations.gelu(x)
  >>> y.numpy()
  array([-0.00404951, -0.15865529,  0.        ,  0.8413447 ,  2.9959507 ],
      dtype=float32)
  >>> y = tf.keras.activations.gelu(x, approximate=True)
  >>> y.numpy()
  array([-0.00363752, -0.15880796,  0.        ,  0.841192  ,  2.9963627 ],
      dtype=float32)

  Args:
    x: Input tensor.
    approximate: A `bool`, whether to enable approximation.

  Returns:
    The gaussian error linear activation:
    `0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3)))`
    if `approximate` is `True`, or
    `x * P(X <= x) = 0.5 * x * (1 + erf(x / sqrt(2)))`,
    where `P(X) ~ N(0, 1)`, if `approximate` is `False`.

  Reference:
    - [Gaussian Error Linear Units (GELUs)](https://arxiv.org/abs/1606.08415)
  """
  # The computation is delegated to the `nn.gelu` op.
  activated = nn.gelu(x, approximate)
  return activated
@keras_export('keras.activations.tanh')
@dispatch.add_dispatch_support
def tanh(x):
  """Applies the hyperbolic tangent activation function.

  For example:

  >>> a = tf.constant([-3.0,-1.0, 0.0,1.0,3.0], dtype = tf.float32)
  >>> b = tf.keras.activations.tanh(a)
  >>> b.numpy()
  array([-0.9950547, -0.7615942,  0.,  0.7615942,  0.9950547], dtype=float32)

  Args:
    x: Input tensor.

  Returns:
    Tensor of the same shape and dtype as the input `x`, with the tanh
    activation applied:
    `tanh(x) = sinh(x)/cosh(x) = ((exp(x) - exp(-x))/(exp(x) + exp(-x)))`.
  """
  # The computation is delegated to the `nn.tanh` op.
  activated = nn.tanh(x)
  return activated
@keras_export('keras.activations.sigmoid')
@dispatch.add_dispatch_support
def sigmoid(x):
  """Applies the sigmoid activation, `sigmoid(x) = 1 / (1 + exp(-x))`.

  For small inputs (<-5) the result is close to zero, and for large inputs
  (>5) it gets close to 1. The returned value always lies between 0 and 1.

  Sigmoid is equivalent to a 2-element softmax in which the second element
  is assumed to be zero.

  For example:

  >>> a = tf.constant([-20, -1.0, 0.0, 1.0, 20], dtype = tf.float32)
  >>> b = tf.keras.activations.sigmoid(a)
  >>> b.numpy()
  array([2.0611537e-09, 2.6894143e-01, 5.0000000e-01, 7.3105860e-01,
           1.0000000e+00], dtype=float32)

  Args:
    x: Input tensor.

  Returns:
    Tensor with the sigmoid activation: `1 / (1 + exp(-x))`. The input `x`
    is attached to the returned tensor as `_keras_logits`.
  """
  probabilities = nn.sigmoid(x)
  # Cache the logits to use for crossentropy loss.
  probabilities._keras_logits = x  # pylint: disable=protected-access
  return probabilities
@keras_export('keras.activations.exponential')
@dispatch.add_dispatch_support
def exponential(x):
  """Applies the exponential activation function, `exp(x)`.

  For example:

  >>> a = tf.constant([-3.0,-1.0, 0.0,1.0,3.0], dtype = tf.float32)
  >>> b = tf.keras.activations.exponential(a)
  >>> b.numpy()
  array([0.04978707,  0.36787945,  1.,  2.7182817 , 20.085537], dtype=float32)

  Args:
    x: Input tensor.

  Returns:
    Tensor with exponential activation: `exp(x)`.
  """
  # The computation is delegated to `math_ops.exp`.
  activated = math_ops.exp(x)
  return activated
@keras_export('keras.activations.hard_sigmoid')
@dispatch.add_dispatch_support
def hard_sigmoid(x):
  """Applies the hard sigmoid activation function.

  A piecewise linear approximation of the sigmoid function that is faster
  to evaluate than the sigmoid activation.
  Ref: 'https://en.wikipedia.org/wiki/Hard_sigmoid'

  For example:

  >>> a = tf.constant([-3.0,-1.0, 0.0,1.0,3.0], dtype = tf.float32)
  >>> b = tf.keras.activations.hard_sigmoid(a)
  >>> b.numpy()
  array([0. , 0.3, 0.5, 0.7, 1. ], dtype=float32)

  Args:
    x: Input tensor.

  Returns:
    The hard sigmoid activation, defined as:

      - `if x < -2.5: return 0`
      - `if x > 2.5: return 1`
      - `if -2.5 <= x <= 2.5: return 0.2 * x + 0.5`
  """
  # The computation is delegated to the Keras backend.
  activated = backend.hard_sigmoid(x)
  return activated
@keras_export('keras.activations.linear')
@dispatch.add_dispatch_support
def linear(x):
  """Identity ("pass-through") activation: returns its input as-is.

  For example:

  >>> a = tf.constant([-3.0,-1.0, 0.0,1.0,3.0], dtype = tf.float32)
  >>> b = tf.keras.activations.linear(a)
  >>> b.numpy()
  array([-3., -1.,  0.,  1.,  3.], dtype=float32)

  Args:
    x: Input tensor.

  Returns:
    The input, unmodified.
  """
  # No computation is performed; the very same object is handed back.
  unchanged = x
  return unchanged
@keras_export('keras.activations.serialize')
@dispatch.add_dispatch_support
def serialize(activation):
  """Returns the string identifier of an activation function.

  A function whose `__name__` is a key of `_TF_ACTIVATIONS_V2` is reported
  under the canonical name stored in that mapping. Everything else is handed
  to `serialize_keras_object`.

  Args:
    activation : Function object.

  Returns:
    String denoting the name attribute of the input function

  For example:

  >>> tf.keras.activations.serialize(tf.keras.activations.tanh)
  'tanh'
  >>> tf.keras.activations.serialize(tf.keras.activations.sigmoid)
  'sigmoid'
  >>> tf.keras.activations.serialize('abcd')
  Traceback (most recent call last):
  ...
  ValueError: ('Cannot serialize', 'abcd')

  Raises:
    ValueError: The input function is not a valid one.
  """
  # Objects without a `__name__` yield None here, which is never a key of the
  # v2 mapping, so they fall through to the generic serializer.
  fn_name = getattr(activation, '__name__', None)
  if fn_name in _TF_ACTIVATIONS_V2:
    return _TF_ACTIVATIONS_V2[fn_name]
  return serialize_keras_object(activation)
# Add additional globals so that deserialize can find these common activation
# functions. `deserialize` passes this module's `globals()` as the
# `module_objects` lookup table, so binding the names here makes them
# resolvable by their string identifier.
leaky_relu = nn.leaky_relu
log_softmax = nn.log_softmax
relu6 = nn.relu6
# `silu` is bound to `nn.swish`, the same op that `swish` in this module wraps.
silu = nn.swish
@keras_export('keras.activations.deserialize')
@dispatch.add_dispatch_support
def deserialize(name, custom_objects=None):
  """Returns the activation function denoted by a string identifier.

  The lookup table is this module's global namespace, extended with the
  entries of `advanced_activations.get_globals()`. Names already defined in
  this module take precedence; only missing names are added. Note that the
  added entries are written into this module's globals.

  Args:
    name: The name of the activation function.
    custom_objects: Optional `{function_name: function_obj}`
      dictionary listing user-provided activation functions.

  Returns:
    Corresponding activation function.

  For example:

  >>> tf.keras.activations.deserialize('linear')
   <function linear at 0x1239596a8>
  >>> tf.keras.activations.deserialize('sigmoid')
   <function sigmoid at 0x123959510>
  >>> tf.keras.activations.deserialize('abcd')
  Traceback (most recent call last):
  ...
  ValueError: Unknown activation function:abcd

  Raises:
    ValueError: `Unknown activation function` if the input string does not
      denote any defined Tensorflow activation function.
  """
  module_objects = globals()
  # Only fill in names this module does not define itself; `setdefault`
  # leaves existing entries untouched.
  for obj_name, obj in advanced_activations.get_globals().items():
    module_objects.setdefault(obj_name, obj)

  return deserialize_keras_object(
      name,
      module_objects=module_objects,
      custom_objects=custom_objects,
      printable_module_name='activation function')
@keras_export('keras.activations.get')
@dispatch.add_dispatch_support
def get(identifier):
  """Resolves an activation identifier to a callable.

  Resolution rules, applied in order:

  - `None` resolves to `linear`.
  - A string or a dict is passed to `deserialize`.
  - Any other callable is returned unchanged.

  Args:
    identifier: Function, string, dict or `None`.

  Returns:
    Function corresponding to the input string or input function.

  For example:

  >>> tf.keras.activations.get('softmax')
   <function softmax at 0x1222a3d90>
  >>> tf.keras.activations.get(tf.keras.activations.softmax)
   <function softmax at 0x1222a3d90>
  >>> tf.keras.activations.get(None)
   <function linear at 0x1239596a8>
  >>> tf.keras.activations.get(abs)
   <built-in function abs>
  >>> tf.keras.activations.get('abcd')
  Traceback (most recent call last):
  ...
  ValueError: Unknown activation function:abcd

  Raises:
    ValueError: Input is an unknown function or string, i.e., the input does
      not denote any defined function (raised by `deserialize`).
    TypeError: Input is neither `None`, a string, a dict, nor a callable.
  """
  if identifier is None:
    return linear
  # Strings and dicts (serialized configs) both go through `deserialize`; the
  # value is already a `str` when the first check matches, so no cast is needed.
  if isinstance(identifier, (str, dict)):
    return deserialize(identifier)
  if callable(identifier):
    return identifier
  raise TypeError(
      'Could not interpret activation function identifier: {}'.format(
          identifier))