13 changes: 13 additions & 0 deletions tensorflow_addons/activations/BUILD
@@ -19,6 +19,19 @@ py_library(
srcs_version = "PY2AND3",
)

py_test(
name = "activations_test",
size = "small",
srcs = [
"activations_test.py",
],
main = "activations_test.py",
srcs_version = "PY2AND3",
deps = [
":activations",
],
)

py_test(
name = "sparsemax_test",
size = "small",
1 change: 1 addition & 0 deletions tensorflow_addons/activations/README.md
@@ -35,6 +35,7 @@ must:
or `run_all_in_graph_and_eager_modes` (for TestCase subclass)
decorator.
* Add a `py_test` to this sub-package's BUILD file.
* Add activation name to [activations_test.py](https://github.com/tensorflow/addons/tree/master/tensorflow_addons/activations/activations_test.py) to test serialization.

#### Documentation Requirements
* Update the table of contents in this sub-package's README.
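For concreteness, here is a hedged sketch of what the new checklist item amounts to for a hypothetical activation (`softshrink` is purely illustrative and not part of this change): the function is registered as a Keras custom object, following the pattern the existing activations in this package use, and its name is appended to `ALL_ACTIVATIONS` in `activations_test.py`. The `keras_utils` import path below is an assumption based on the decorator seen in `sparsemax.py`.

```python
# Hypothetical new activation; not part of this PR.
import tensorflow as tf
from tensorflow_addons.utils import keras_utils


@keras_utils.register_keras_custom_object
@tf.function
def softshrink(x, lower=-0.5, upper=0.5):
    """Shifts x toward zero outside [lower, upper]; zero inside."""
    x = tf.convert_to_tensor(x)
    return tf.where(x < lower, x - lower,
                    tf.where(x > upper, x - upper, tf.zeros_like(x)))


# In activations_test.py, the new name joins the shared list so the
# serialization tests cover it automatically:
ALL_ACTIVATIONS = [
    "gelu", "hardshrink", "lisht", "softshrink", "sparsemax", "tanhshrink"
]
```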
49 changes: 49 additions & 0 deletions tensorflow_addons/activations/activations_test.py
@@ -0,0 +1,49 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
from tensorflow_addons import activations
from tensorflow_addons.utils import test_utils


@test_utils.run_all_in_graph_and_eager_modes
class ActivationsTest(tf.test.TestCase):

ALL_ACTIVATIONS = [
"gelu", "hardshrink", "lisht", "sparsemax", "tanhshrink"
]

def test_serialization(self):
for name in self.ALL_ACTIVATIONS:
fn = tf.keras.activations.get(name)
ref_fn = getattr(activations, name)
self.assertEqual(fn, ref_fn)
config = tf.keras.activations.serialize(fn)
fn = tf.keras.activations.deserialize(config)
self.assertEqual(fn, ref_fn)

def test_serialization_with_layers(self):
for name in self.ALL_ACTIVATIONS:
layer = tf.keras.layers.Dense(
3, activation=getattr(activations, name))
config = tf.keras.layers.serialize(layer)
deserialized_layer = tf.keras.layers.deserialize(config)
self.assertEqual(deserialized_layer.__class__.__name__,
layer.__class__.__name__)
self.assertEqual(deserialized_layer.activation.__name__, name)
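This shared test replaces the per-activation serialization tests deleted in the files below. What the registration it exercises buys in practice is name-based lookup and config round-tripping; a small sketch (assuming `tensorflow_addons` is importable so the custom objects get registered on import):

```python
import tensorflow as tf
from tensorflow_addons import activations  # importing registers the Keras custom objects

# Name-based lookup resolves to the addons implementation.
fn = tf.keras.activations.get("gelu")
assert fn == activations.gelu

# The name also survives a layer-config round trip.
layer = tf.keras.layers.Dense(3, activation="hardshrink")
config = tf.keras.layers.serialize(layer)
restored = tf.keras.layers.deserialize(config)
assert restored.activation.__name__ == "hardshrink"
```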
55 changes: 8 additions & 47 deletions tensorflow_addons/activations/gelu_test.py
@@ -19,58 +19,33 @@

from absl.testing import parameterized

import math

import numpy as np
import tensorflow as tf
from tensorflow_addons.activations import gelu
from tensorflow_addons.utils import test_utils


def _ref_gelu(x, approximate=True):
x = tf.convert_to_tensor(x)
if approximate:
pi = tf.cast(math.pi, x.dtype)
coeff = tf.cast(0.044715, x.dtype)
return 0.5 * x * (
1.0 + tf.tanh(tf.sqrt(2.0 / pi) * (x + coeff * tf.pow(x, 3))))
else:
return 0.5 * x * (
1.0 + tf.math.erf(x / tf.cast(tf.sqrt(2.0), x.dtype)))


@test_utils.run_all_in_graph_and_eager_modes
class GeluTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.named_parameters(("float16", np.float16),
("float32", np.float32),
("float64", np.float64))
def test_gelu(self, dtype):
x = np.random.rand(2, 3, 4).astype(dtype)
self.assertAllCloseAccordingToType(gelu(x), _ref_gelu(x))
self.assertAllCloseAccordingToType(gelu(x, False), _ref_gelu(x, False))
x = tf.constant([-2.0, -1.0, 0.0, 1.0, 2.0], dtype=dtype)
expected_result = tf.constant(
[-0.04540229, -0.158808, 0.0, 0.841192, 1.9545977], dtype=dtype)
self.assertAllCloseAccordingToType(gelu(x), expected_result)

@parameterized.named_parameters(("float16", np.float16),
("float32", np.float32),
("float64", np.float64))
def test_gradients(self, dtype):
x = tf.constant([1.0, 2.0, 3.0], dtype=dtype)

for approximate in [True, False]:
with self.subTest(approximate=approximate):
with tf.GradientTape(persistent=True) as tape:
tape.watch(x)
y_ref = _ref_gelu(x, approximate)
y = gelu(x, approximate)
grad_ref = tape.gradient(y_ref, x)
grad = tape.gradient(y, x)
self.assertAllCloseAccordingToType(grad, grad_ref)
expected_result = tf.constant(
[-0.04550028, -0.15865526, 0.0, 0.8413447, 1.9544997], dtype=dtype)
self.assertAllCloseAccordingToType(gelu(x, False), expected_result)

@parameterized.named_parameters(("float32", np.float32),
("float64", np.float64))
def test_theoretical_gradients(self, dtype):
# Only test theoretical gradients for float32 and float64
# because of the instability of float16 while computing jacobian
x = tf.constant([1.0, 2.0, 3.0], dtype=dtype)
x = tf.constant([-2.0, -1.0, 0.0, 1.0, 2.0], dtype=dtype)

for approximate in [True, False]:
with self.subTest(approximate=approximate):
@@ -87,20 +62,6 @@ def test_unknown_shape(self):
x = tf.ones(shape=shape, dtype=tf.float32)
self.assertAllClose(fn(x), gelu(x))

def test_serialization(self):
ref_fn = gelu
config = tf.keras.activations.serialize(ref_fn)
fn = tf.keras.activations.deserialize(config)
self.assertEqual(fn, ref_fn)

def test_serialization_with_layers(self):
layer = tf.keras.layers.Dense(3, activation=gelu)
config = tf.keras.layers.serialize(layer)
deserialized_layer = tf.keras.layers.deserialize(config)
self.assertEqual(deserialized_layer.__class__.__name__,
layer.__class__.__name__)
self.assertEqual(deserialized_layer.activation.__name__, "gelu")


if __name__ == "__main__":
tf.test.main()
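The hard-coded `expected_result` values above stand in for the removed `_ref_gelu` comparison. As a cross-check, a minimal NumPy sketch (not part of the PR) that reproduces them from the standard tanh-approximate and exact GELU definitions:

```python
import numpy as np
from scipy.special import erf


def ref_gelu(x, approximate=True):
    if approximate:
        # Tanh approximation of GELU.
        return 0.5 * x * (1.0 + np.tanh(
            np.sqrt(2.0 / np.pi) * (x + 0.044715 * np.power(x, 3))))
    # Exact form via the Gaussian CDF.
    return 0.5 * x * (1.0 + erf(x / np.sqrt(2.0)))


x = np.array([-2.0, -1.0, 0.0, 1.0, 2.0])
print(ref_gelu(x))         # ~[-0.0454023, -0.158808, 0., 0.841192, 1.9545977]
print(ref_gelu(x, False))  # ~[-0.0455003, -0.1586553, 0., 0.8413447, 1.9544997]
```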
50 changes: 11 additions & 39 deletions tensorflow_addons/activations/hardshrink_test.py
@@ -25,11 +25,6 @@
from tensorflow_addons.utils import test_utils


def _ref_hardshrink(x, lower=-1.0, upper=1.0):
x = tf.convert_to_tensor(x)
return tf.where(tf.math.logical_or(x < lower, x > upper), x, 0.0)


@test_utils.run_all_in_graph_and_eager_modes
class HardshrinkTest(tf.test.TestCase, parameterized.TestCase):
def test_invalid(self):
@@ -42,34 +37,25 @@ def test_invalid(self):
("float32", np.float32),
("float64", np.float64))
def test_hardshrink(self, dtype):
x = (np.random.rand(2, 3, 4) * 2.0 - 1.0).astype(dtype)
self.assertAllCloseAccordingToType(hardshrink(x), _ref_hardshrink(x))
self.assertAllCloseAccordingToType(
hardshrink(x, -2.0, 2.0), _ref_hardshrink(x, -2.0, 2.0))
x = tf.constant([-2.0, -1.0, 0.0, 1.0, 2.0], dtype=dtype)
expected_result = tf.constant([-2.0, 0.0, 0.0, 0.0, 2.0], dtype=dtype)
self.assertAllCloseAccordingToType(hardshrink(x), expected_result)

@parameterized.named_parameters(("float16", np.float16),
("float32", np.float32),
("float64", np.float64))
def test_gradients(self, dtype):
x = tf.constant([-1.5, -0.5, 0.5, 1.5], dtype=dtype)

with tf.GradientTape(persistent=True) as tape:
tape.watch(x)
y_ref = _ref_hardshrink(x)
y = hardshrink(x)
grad_ref = tape.gradient(y_ref, x)
grad = tape.gradient(y, x)
self.assertAllCloseAccordingToType(grad, grad_ref)
expected_result = tf.constant([-2.0, -1.0, 0.0, 1.0, 2.0], dtype=dtype)
self.assertAllCloseAccordingToType(
hardshrink(x, lower=-0.5, upper=0.5), expected_result)

@parameterized.named_parameters(("float32", np.float32),
("float64", np.float64))
def test_theoretical_gradients(self, dtype):
# Only test theoretical gradients for float32 and float64
# because of the instability of float16 while computing jacobian
x = tf.constant([-1.5, -0.5, 0.5, 1.5], dtype=dtype)

theoretical, numerical = tf.test.compute_gradient(
lambda x: hardshrink(x), [x])
# Hardshrink is not continuous at `lower` and `upper`.
# Avoid these two points to make gradients smooth.
x = tf.constant([-2.0, -1.5, 0.0, 1.5, 2.0], dtype=dtype)

theoretical, numerical = tf.test.compute_gradient(hardshrink, [x])
self.assertAllCloseAccordingToType(theoretical, numerical, atol=1e-4)

def test_unknown_shape(self):
@@ -80,20 +66,6 @@ def test_unknown_shape(self):
x = tf.ones(shape=shape, dtype=tf.float32)
self.assertAllClose(fn(x), hardshrink(x))

def test_serialization(self):
ref_fn = hardshrink
config = tf.keras.activations.serialize(ref_fn)
fn = tf.keras.activations.deserialize(config)
self.assertEqual(fn, ref_fn)

def test_serialization_with_layers(self):
layer = tf.keras.layers.Dense(3, activation=hardshrink)
config = tf.keras.layers.serialize(layer)
deserialized_layer = tf.keras.layers.deserialize(config)
self.assertEqual(deserialized_layer.__class__.__name__,
layer.__class__.__name__)
self.assertEqual(deserialized_layer.activation.__name__, "hardshrink")


if __name__ == "__main__":
tf.test.main()
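Same idea as in `gelu_test.py`: golden values replace the `_ref_hardshrink` helper. A quick NumPy cross-check (illustrative only), which also shows why the gradient test now steps around `lower` and `upper`, where the function jumps:

```python
import numpy as np


def ref_hardshrink(x, lower=-1.0, upper=1.0):
    # Identity outside [lower, upper], zero inside; discontinuous at the bounds.
    return np.where((x < lower) | (x > upper), x, 0.0)


x = np.array([-2.0, -1.0, 0.0, 1.0, 2.0])
print(ref_hardshrink(x))                         # [-2.  0.  0.  0.  2.]
print(ref_hardshrink(x, lower=-0.5, upper=0.5))  # [-2. -1.  0.  1.  2.]
```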
13 changes: 0 additions & 13 deletions tensorflow_addons/activations/lisht_test.py
@@ -55,19 +55,6 @@ def test_unknown_shape(self):
x = tf.ones(shape=shape, dtype=tf.float32)
self.assertAllClose(fn(x), lisht(x))

def test_serialization(self):
config = tf.keras.activations.serialize(lisht)
fn = tf.keras.activations.deserialize(config)
self.assertEqual(fn, lisht)

def test_serialization_with_layers(self):
layer = tf.keras.layers.Dense(3, activation=lisht)
config = tf.keras.layers.serialize(layer)
deserialized_layer = tf.keras.layers.deserialize(config)
self.assertEqual(deserialized_layer.__class__.__name__,
layer.__class__.__name__)
self.assertEqual(deserialized_layer.activation.__name__, "lisht")


if __name__ == "__main__":
tf.test.main()
12 changes: 5 additions & 7 deletions tensorflow_addons/activations/sparsemax.py
@@ -24,7 +24,7 @@

@keras_utils.register_keras_custom_object
@tf.function
def sparsemax(logits, axis=-1, name=None):
def sparsemax(logits, axis=-1):
"""Sparsemax activation function [1].

For each batch `i` and class `j` we have
@@ -35,7 +35,6 @@ def sparsemax(logits, axis=-1, name=None):
Args:
logits: Input tensor.
axis: Integer, axis along which the sparsemax operation is applied.
name: A name for the operation (optional).
Returns:
Tensor, output of sparsemax transformation. Has the same type and
shape as `logits`.
@@ -50,7 +49,7 @@
is_last_axis = (axis == -1) or (axis == rank - 1)

if is_last_axis:
output = _compute_2d_sparsemax(logits, name=name)
output = _compute_2d_sparsemax(logits)
output.set_shape(shape)
return output

@@ -64,8 +63,7 @@

    # Do the actual sparsemax on its last dimension.
output = _compute_2d_sparsemax(logits)
output = _swap_axis(
output, axis_norm, tf.math.subtract(rank_op, 1), name=name)
output = _swap_axis(output, axis_norm, tf.math.subtract(rank_op, 1))

# Make shape inference work since transpose may erase its static shape.
output.set_shape(shape)
@@ -82,7 +80,7 @@ def _swap_axis(logits, dim_index, last_index, **kwargs):


@tf.function
def _compute_2d_sparsemax(logits, name=None):
def _compute_2d_sparsemax(logits):
"""Performs the sparsemax operation when axis=-1."""
shape_op = tf.shape(logits)
obs = tf.math.reduce_prod(shape_op[:-1])
@@ -134,5 +132,5 @@ def _compute_2d_sparsemax(logits, name=None):
logits.dtype)), p)

# Reshape back to original size
p_safe = tf.reshape(p_safe, shape_op, name=name)
p_safe = tf.reshape(p_safe, shape_op)
return p_safe
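This file drops the `name` argument from `sparsemax` and its `_compute_2d_sparsemax` helper; the transformation itself is unchanged. For reference, a minimal usage sketch of the updated signature (the example values are mine, not from the PR):

```python
import tensorflow as tf
from tensorflow_addons.activations import sparsemax

logits = tf.constant([[0.1, 1.1, 0.6],
                      [1.0, 1.0, 1.0]])
probs = sparsemax(logits, axis=-1)
# Each row sums to 1, but unlike softmax the low-scoring entries are driven
# exactly to zero: roughly [[0., 0.75, 0.25], [0.3333, 0.3333, 0.3333]].
```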
14 changes: 0 additions & 14 deletions tensorflow_addons/activations/sparsemax_test.py
@@ -274,20 +274,6 @@ def test_gradient_against_estimate(self, dtype=None):
lambda logits: sparsemax(logits), [z], delta=1e-6)
self.assertAllCloseAccordingToType(jacob_sym, jacob_num)

def test_serialization(self, dtype=None):
ref_fn = sparsemax
config = tf.keras.activations.serialize(ref_fn)
fn = tf.keras.activations.deserialize(config)
self.assertEqual(fn, ref_fn)

def test_serialization_with_layers(self, dtype=None):
layer = tf.keras.layers.Dense(3, activation=sparsemax)
config = tf.keras.layers.serialize(layer)
deserialized_layer = tf.keras.layers.deserialize(config)
self.assertEqual(deserialized_layer.__class__.__name__,
layer.__class__.__name__)
self.assertEqual(deserialized_layer.activation.__name__, "sparsemax")


if __name__ == '__main__':
tf.test.main()
36 changes: 13 additions & 23 deletions tensorflow_addons/activations/tanhshrink_test.py
@@ -25,37 +25,27 @@
from tensorflow_addons.utils import test_utils


def _ref_tanhshrink(x):
return x - tf.tanh(x)


@test_utils.run_all_in_graph_and_eager_modes
class TanhshrinkTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.named_parameters(("float16", np.float16),
("float32", np.float32),
("float64", np.float64))
def test_tanhshrink(self, dtype):
x = tf.constant([1.0, 2.0, 3.0], dtype=dtype)
self.assertAllCloseAccordingToType(tanhshrink(x), _ref_tanhshrink(x))
x = tf.constant([-2.0, -1.0, 0.0, 1.0, 2.0], dtype=dtype)
expected_result = tf.constant(
[-1.0359724, -0.23840582, 0.0, 0.23840582, 1.0359724], dtype=dtype)

@parameterized.named_parameters(("float16", np.float16),
("float32", np.float32),
self.assertAllCloseAccordingToType(tanhshrink(x), expected_result)

@parameterized.named_parameters(("float32", np.float32),
("float64", np.float64))
def test_gradients(self, dtype):
x = tf.constant([1.0, 2.0, 3.0], dtype=dtype)
with tf.GradientTape(persistent=True) as tape:
tape.watch(x)
y_ref = _ref_tanhshrink(x)
y = tanhshrink(x)
grad_ref = tape.gradient(y_ref, x)
grad = tape.gradient(y, x)
self.assertAllCloseAccordingToType(grad, grad_ref)

def test_serialization(self):
ref_fn = tanhshrink
config = tf.keras.activations.serialize(ref_fn)
fn = tf.keras.activations.deserialize(config)
self.assertEqual(fn, ref_fn)
def test_theoretical_gradients(self, dtype):
# Only test theoretical gradients for float32 and float64
# because of the instability of float16 while computing jacobian
x = tf.constant([-2.0, -1.0, 0.0, 1.0, 2.0], dtype=dtype)

theoretical, numerical = tf.test.compute_gradient(tanhshrink, [x])
self.assertAllCloseAccordingToType(theoretical, numerical, atol=1e-4)


if __name__ == "__main__":
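As in the other activation tests, the removed `_ref_tanhshrink` helper was simply `x - tanh(x)`, so the new expected constants can be reproduced directly (illustrative cross-check, not part of the PR):

```python
import numpy as np

x = np.array([-2.0, -1.0, 0.0, 1.0, 2.0])
print(x - np.tanh(x))  # ~[-1.0359724, -0.2384058, 0., 0.2384058, 1.0359724]
```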
@@ -75,5 +75,5 @@ TF_CALL_GPU_NUMBER_TYPES(REGISTER_GELU_GPU_KERNELS);

#endif // GOOGLE_CUDA

} // end namespace addons
} // namespace tensorflow
} // namespace addons
} // namespace tensorflow