From 507a3290441d51132680291fa7fc69dd3493ce77 Mon Sep 17 00:00:00 2001
From: Tzu-Wei Sung
Date: Sun, 29 Sep 2019 12:55:24 -0700
Subject: [PATCH 1/7] clean up activation/test

---
 tensorflow_addons/activations/gelu_test.py    | 44 +++++--------------
 .../activations/hardshrink_test.py            | 41 ++++++-----------
 tensorflow_addons/activations/sparsemax.py    | 12 +++--
 .../activations/tanhshrink_test.py            | 43 +++++++++---------
 .../activations/cc/kernels/gelu_op.cc         |  4 +-
 .../activations/cc/kernels/gelu_op.h          |  4 +-
 .../activations/cc/kernels/gelu_op_gpu.cu.cc  |  2 +-
 .../activations/cc/kernels/hardshrink_op.cc   |  2 +-
 .../activations/cc/kernels/hardshrink_op.h    |  4 +-
 .../cc/kernels/hardshrink_op_gpu.cu.cc        |  2 +-
 .../custom_ops/activations/cc/ops/gelu_op.cc  |  4 +-
 .../activations/cc/ops/hardshrink_op.cc       |  2 +
 12 files changed, 64 insertions(+), 100 deletions(-)

diff --git a/tensorflow_addons/activations/gelu_test.py b/tensorflow_addons/activations/gelu_test.py
index f510715593..33254e75c2 100644
--- a/tensorflow_addons/activations/gelu_test.py
+++ b/tensorflow_addons/activations/gelu_test.py
@@ -27,50 +27,27 @@
 from tensorflow_addons.utils import test_utils
 
 
-def _ref_gelu(x, approximate=True):
-    x = tf.convert_to_tensor(x)
-    if approximate:
-        pi = tf.cast(math.pi, x.dtype)
-        coeff = tf.cast(0.044715, x.dtype)
-        return 0.5 * x * (
-            1.0 + tf.tanh(tf.sqrt(2.0 / pi) * (x + coeff * tf.pow(x, 3))))
-    else:
-        return 0.5 * x * (
-            1.0 + tf.math.erf(x / tf.cast(tf.sqrt(2.0), x.dtype)))
-
-
 @test_utils.run_all_in_graph_and_eager_modes
 class GeluTest(tf.test.TestCase, parameterized.TestCase):
     @parameterized.named_parameters(("float16", np.float16),
                                     ("float32", np.float32),
                                     ("float64", np.float64))
     def test_gelu(self, dtype):
-        x = np.random.rand(2, 3, 4).astype(dtype)
-        self.assertAllCloseAccordingToType(gelu(x), _ref_gelu(x))
-        self.assertAllCloseAccordingToType(gelu(x, False), _ref_gelu(x, False))
-
-    @parameterized.named_parameters(("float16", np.float16),
-                                    ("float32", np.float32),
-                                    ("float64", np.float64))
-    def test_gradients(self, dtype):
-        x = tf.constant([1.0, 2.0, 3.0], dtype=dtype)
+        x = tf.constant([-2.0, -1.0, 0.0, 1.0, 2.0], dtype=dtype)
+        expected_result = tf.constant(
+            [-0.04540229, -0.158808, 0.0, 0.841192, 1.9545977], dtype=dtype)
+        self.assertAllCloseAccordingToType(gelu(x), expected_result)
 
-        for approximate in [True, False]:
-            with self.subTest(approximate=approximate):
-                with tf.GradientTape(persistent=True) as tape:
-                    tape.watch(x)
-                    y_ref = _ref_gelu(x, approximate)
-                    y = gelu(x, approximate)
-                grad_ref = tape.gradient(y_ref, x)
-                grad = tape.gradient(y, x)
-                self.assertAllCloseAccordingToType(grad, grad_ref)
+        expected_result = tf.constant(
+            [-0.04550028, -0.15865526, 0.0, 0.8413447, 1.9544997], dtype=dtype)
+        self.assertAllCloseAccordingToType(gelu(x, False), expected_result)
 
     @parameterized.named_parameters(("float32", np.float32),
                                     ("float64", np.float64))
     def test_theoretical_gradients(self, dtype):
         # Only test theoretical gradients for float32 and float64
         # because of the instability of float16 while computing jacobian
-        x = tf.constant([1.0, 2.0, 3.0], dtype=dtype)
+        x = tf.constant([-2.0, -1.0, 0.0, 1.0, 2.0], dtype=dtype)
 
         for approximate in [True, False]:
             with self.subTest(approximate=approximate):
@@ -88,10 +65,9 @@ def test_unknown_shape(self):
         x = tf.ones(shape=shape, dtype=tf.float32)
         self.assertAllClose(fn(x), gelu(x))
 
     def test_serialization(self):
-        ref_fn = gelu
-        config = tf.keras.activations.serialize(ref_fn)
+        config = tf.keras.activations.serialize(gelu)
         fn = tf.keras.activations.deserialize(config)
-        self.assertEqual(fn, ref_fn)
+        self.assertEqual(fn, gelu)
 
     def test_serialization_with_layers(self):
         layer = tf.keras.layers.Dense(3, activation=gelu)
diff --git a/tensorflow_addons/activations/hardshrink_test.py b/tensorflow_addons/activations/hardshrink_test.py
index a16b9be3b9..98a9af871f 100644
--- a/tensorflow_addons/activations/hardshrink_test.py
+++ b/tensorflow_addons/activations/hardshrink_test.py
@@ -25,11 +25,6 @@
 from tensorflow_addons.utils import test_utils
 
 
-def _ref_hardshrink(x, lower=-1.0, upper=1.0):
-    x = tf.convert_to_tensor(x)
-    return tf.where(tf.math.logical_or(x < lower, x > upper), x, 0.0)
-
-
 @test_utils.run_all_in_graph_and_eager_modes
 class HardshrinkTest(tf.test.TestCase, parameterized.TestCase):
     def test_invalid(self):
@@ -42,34 +37,25 @@ def test_invalid(self):
                                     ("float32", np.float32),
                                     ("float64", np.float64))
     def test_hardshrink(self, dtype):
-        x = (np.random.rand(2, 3, 4) * 2.0 - 1.0).astype(dtype)
-        self.assertAllCloseAccordingToType(hardshrink(x), _ref_hardshrink(x))
-        self.assertAllCloseAccordingToType(
-            hardshrink(x, -2.0, 2.0), _ref_hardshrink(x, -2.0, 2.0))
+        x = tf.constant([-2.0, -1.0, 0.0, 1.0, 2.0], dtype=dtype)
+        expected_result = tf.constant([-2.0, 0.0, 0.0, 0.0, 2.0], dtype=dtype)
+        self.assertAllCloseAccordingToType(hardshrink(x), expected_result)
 
-    @parameterized.named_parameters(("float16", np.float16),
-                                    ("float32", np.float32),
-                                    ("float64", np.float64))
-    def test_gradients(self, dtype):
-        x = tf.constant([-1.5, -0.5, 0.5, 1.5], dtype=dtype)
-
-        with tf.GradientTape(persistent=True) as tape:
-            tape.watch(x)
-            y_ref = _ref_hardshrink(x)
-            y = hardshrink(x)
-        grad_ref = tape.gradient(y_ref, x)
-        grad = tape.gradient(y, x)
-        self.assertAllCloseAccordingToType(grad, grad_ref)
+        expected_result = tf.constant([-2.0, -1.0, 0.0, 1.0, 2.0], dtype=dtype)
+        self.assertAllCloseAccordingToType(
+            hardshrink(x, lower=-0.5, upper=0.5), expected_result)
 
     @parameterized.named_parameters(("float32", np.float32),
                                     ("float64", np.float64))
     def test_theoretical_gradients(self, dtype):
         # Only test theoretical gradients for float32 and float64
         # because of the instability of float16 while computing jacobian
-        x = tf.constant([-1.5, -0.5, 0.5, 1.5], dtype=dtype)
-        theoretical, numerical = tf.test.compute_gradient(
-            lambda x: hardshrink(x), [x])
+        # Hardshrink is not continuous at `lower` and `upper`.
+        # Avoid these two points to make gradients smooth.
+        x = tf.constant([-2.0, -1.5, 0.0, 1.5, 2.0], dtype=dtype)
+
+        theoretical, numerical = tf.test.compute_gradient(hardshrink, [x])
         self.assertAllCloseAccordingToType(theoretical, numerical, atol=1e-4)
 
     def test_unknown_shape(self):
@@ -81,10 +67,9 @@ def test_unknown_shape(self):
         x = tf.ones(shape=shape, dtype=tf.float32)
         self.assertAllClose(fn(x), hardshrink(x))
 
     def test_serialization(self):
-        ref_fn = hardshrink
-        config = tf.keras.activations.serialize(ref_fn)
+        config = tf.keras.activations.serialize(hardshrink)
         fn = tf.keras.activations.deserialize(config)
-        self.assertEqual(fn, ref_fn)
+        self.assertEqual(fn, hardshrink)
 
     def test_serialization_with_layers(self):
         layer = tf.keras.layers.Dense(3, activation=hardshrink)
diff --git a/tensorflow_addons/activations/sparsemax.py b/tensorflow_addons/activations/sparsemax.py
index a72a5d5ba0..6f897b0ba8 100644
--- a/tensorflow_addons/activations/sparsemax.py
+++ b/tensorflow_addons/activations/sparsemax.py
@@ -24,7 +24,7 @@
 
 @keras_utils.register_keras_custom_object
 @tf.function
-def sparsemax(logits, axis=-1, name=None):
+def sparsemax(logits, axis=-1):
     """Sparsemax activation function [1].
 
     For each batch `i` and class `j` we have
@@ -35,7 +35,6 @@ def sparsemax(logits, axis=-1):
     Args:
         logits: Input tensor.
        axis: Integer, axis along which the sparsemax operation is applied.
-        name: A name for the operation (optional).
     Returns:
         Tensor, output of sparsemax transformation. Has the same type and
         shape as `logits`.
@@ -50,7 +49,7 @@ def sparsemax(logits, axis=-1):
     is_last_axis = (axis == -1) or (axis == rank - 1)
 
     if is_last_axis:
-        output = _compute_2d_sparsemax(logits, name=name)
+        output = _compute_2d_sparsemax(logits)
         output.set_shape(shape)
         return output
 
@@ -64,8 +63,7 @@ def sparsemax(logits, axis=-1):
     # Do the actual softmax on its last dimension.
     output = _compute_2d_sparsemax(logits)
-    output = _swap_axis(
-        output, axis_norm, tf.math.subtract(rank_op, 1), name=name)
+    output = _swap_axis(output, axis_norm, tf.math.subtract(rank_op, 1))
 
     # Make shape inference work since transpose may erase its static shape.
     output.set_shape(shape)
@@ -82,7 +80,7 @@ def _swap_axis(logits, dim_index, last_index, **kwargs):
 
 
 @tf.function
-def _compute_2d_sparsemax(logits, name=None):
+def _compute_2d_sparsemax(logits):
     """Performs the sparsemax operation when axis=-1."""
     shape_op = tf.shape(logits)
     obs = tf.math.reduce_prod(shape_op[:-1])
@@ -134,5 +132,5 @@ def _compute_2d_sparsemax(logits, name=None):
                 logits.dtype)), p)
 
     # Reshape back to original size
-    p_safe = tf.reshape(p_safe, shape_op, name=name)
+    p_safe = tf.reshape(p_safe, shape_op)
     return p_safe
diff --git a/tensorflow_addons/activations/tanhshrink_test.py b/tensorflow_addons/activations/tanhshrink_test.py
index 86d72629bb..59d3e43e3f 100644
--- a/tensorflow_addons/activations/tanhshrink_test.py
+++ b/tensorflow_addons/activations/tanhshrink_test.py
@@ -25,37 +25,40 @@
 from tensorflow_addons.utils import test_utils
 
 
-def _ref_tanhshrink(x):
-    return x - tf.tanh(x)
-
-
 @test_utils.run_all_in_graph_and_eager_modes
 class TanhshrinkTest(tf.test.TestCase, parameterized.TestCase):
     @parameterized.named_parameters(("float16", np.float16),
                                     ("float32", np.float32),
                                     ("float64", np.float64))
     def test_tanhshrink(self, dtype):
-        x = tf.constant([1.0, 2.0, 3.0], dtype=dtype)
-        self.assertAllCloseAccordingToType(tanhshrink(x), _ref_tanhshrink(x))
+        x = tf.constant([-2.0, -1.0, 0.0, 1.0, 2.0], dtype=dtype)
+        expected_result = tf.constant(
+            [-1.0359724, -0.23840582, 0.0, 0.23840582, 1.0359724], dtype=dtype)
 
-    @parameterized.named_parameters(("float16", np.float16),
-                                    ("float32", np.float32),
+        self.assertAllCloseAccordingToType(tanhshrink(x), expected_result)
+
+    @parameterized.named_parameters(("float32", np.float32),
                                     ("float64", np.float64))
-    def test_gradients(self, dtype):
-        x = tf.constant([1.0, 2.0, 3.0], dtype=dtype)
-        with tf.GradientTape(persistent=True) as tape:
-            tape.watch(x)
-            y_ref = _ref_tanhshrink(x)
-            y = tanhshrink(x)
-        grad_ref = tape.gradient(y_ref, x)
-        grad = tape.gradient(y, x)
-        self.assertAllCloseAccordingToType(grad, grad_ref)
+    def test_theoretical_gradients(self, dtype):
+        # Only test theoretical gradients for float32 and float64
+        # because of the instability of float16 while computing jacobian
+        x = tf.constant([-2.0, -1.0, 0.0, 1.0, 2.0], dtype=dtype)
+
+        theoretical, numerical = tf.test.compute_gradient(tanhshrink, [x])
+        self.assertAllCloseAccordingToType(theoretical, numerical, atol=1e-4)
 
     def test_serialization(self):
-        ref_fn = tanhshrink
-        config = tf.keras.activations.serialize(ref_fn)
+        config = tf.keras.activations.serialize(tanhshrink)
         fn = tf.keras.activations.deserialize(config)
-        self.assertEqual(fn, ref_fn)
+        self.assertEqual(fn, tanhshrink)
+
+    def test_serialization_with_layers(self):
+        layer = tf.keras.layers.Dense(3, activation=tanhshrink)
+        config = tf.keras.layers.serialize(layer)
+        deserialized_layer = tf.keras.layers.deserialize(config)
+        self.assertEqual(deserialized_layer.__class__.__name__,
+                         layer.__class__.__name__)
+        self.assertEqual(deserialized_layer.activation.__name__, "tanhshrink")
 
 
 if __name__ == "__main__":
diff --git a/tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op.cc b/tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op.cc
index 52d8a66f1c..0cbc226421 100644
--- a/tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op.cc
+++ b/tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op.cc
@@ -75,5 +75,5 @@ TF_CALL_GPU_NUMBER_TYPES(REGISTER_GELU_GPU_KERNELS);
 
 #endif  // GOOGLE_CUDA
 
-}  // end namespace addons
-}  // namespace tensorflow
\ No newline at end of file
+}  // namespace addons
+}  // namespace tensorflow
diff --git a/tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op.h b/tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op.h
index 174f53068a..094cf58d63 100644
--- a/tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op.h
+++ b/tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op.h
@@ -59,7 +59,7 @@ struct GeluGrad {
   // Computes GeluGrad backprops.
   //
   // gradients: gradients backpropagated to the Gelu op.
-  // features: the inputs that were passed to the Gelu op.
+  // features: inputs that were passed to the Gelu op.
   // approximate: whether to enable approximation.
   // backprops: gradients to backpropagate to the Gelu inputs.
   void operator()(const Device& d, typename TTypes<T>::ConstTensor gradients,
@@ -138,7 +138,7 @@ void GeluGradOp<Device, T>::OperateNoTemplate(OpKernelContext* context,
           approximate, output->flat<T>());
 }
 
-}  // end namespace addons
+}  // namespace addons
 }  // namespace tensorflow
 
 #undef EIGEN_USE_THREADS
diff --git a/tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op_gpu.cu.cc b/tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op_gpu.cu.cc
index c4d0fdaa48..32c6bd435d 100644
--- a/tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op_gpu.cu.cc
+++ b/tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op_gpu.cu.cc
@@ -32,7 +32,7 @@ using GPUDevice = Eigen::GpuDevice;
 
 TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_KERNELS);
 
-}  // end namespace addons
+}  // namespace addons
 }  // namespace tensorflow
 
 #endif  // GOOGLE_CUDA
diff --git a/tensorflow_addons/custom_ops/activations/cc/kernels/hardshrink_op.cc b/tensorflow_addons/custom_ops/activations/cc/kernels/hardshrink_op.cc
index 6003491389..8563d81f64 100644
--- a/tensorflow_addons/custom_ops/activations/cc/kernels/hardshrink_op.cc
+++ b/tensorflow_addons/custom_ops/activations/cc/kernels/hardshrink_op.cc
@@ -77,5 +77,5 @@ TF_CALL_GPU_NUMBER_TYPES(REGISTER_HARDSHRINK_GPU_KERNELS);
 
 #endif  // GOOGLE_CUDA
 
-}  // end namespace addons
+}  // namespace addons
 }  // namespace tensorflow
diff --git a/tensorflow_addons/custom_ops/activations/cc/kernels/hardshrink_op.h b/tensorflow_addons/custom_ops/activations/cc/kernels/hardshrink_op.h
index 1c553efa92..92313dc0eb 100644
--- a/tensorflow_addons/custom_ops/activations/cc/kernels/hardshrink_op.h
+++ b/tensorflow_addons/custom_ops/activations/cc/kernels/hardshrink_op.h
@@ -51,7 +51,7 @@ struct HardshrinkGrad {
   // Computes HardshrinkGrad backprops.
   //
   // gradients: gradients backpropagated to the Hardshink op.
-  // features: the inputs that were passed to the Hardshrink op.
+  // features: inputs that were passed to the Hardshrink op.
   // lower: the lower bound for setting values to zeros.
   // upper: the upper bound for setting values to zeros.
   // backprops: gradients to backpropagate to the Hardshrink inputs.
   void operator()(const Device& d, typename TTypes<T>::ConstTensor gradients,
@@ -136,7 +136,7 @@ void HardshrinkGradOp<Device, T>::OperateNoTemplate(OpKernelContext* context,
           upper, output->flat<T>());
 }
 
-}  // end namespace addons
+}  // namespace addons
 }  // namespace tensorflow
 
 #undef EIGEN_USE_THREADS
diff --git a/tensorflow_addons/custom_ops/activations/cc/kernels/hardshrink_op_gpu.cu.cc b/tensorflow_addons/custom_ops/activations/cc/kernels/hardshrink_op_gpu.cu.cc
index fa904723b5..9b4d2a9e83 100644
--- a/tensorflow_addons/custom_ops/activations/cc/kernels/hardshrink_op_gpu.cu.cc
+++ b/tensorflow_addons/custom_ops/activations/cc/kernels/hardshrink_op_gpu.cu.cc
@@ -32,7 +32,7 @@ using GPUDevice = Eigen::GpuDevice;
 
 TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_KERNELS);
 
-}  // end namespace addons
+}  // namespace addons
 }  // namespace tensorflow
 
 #endif  // GOOGLE_CUDA
diff --git a/tensorflow_addons/custom_ops/activations/cc/ops/gelu_op.cc b/tensorflow_addons/custom_ops/activations/cc/ops/gelu_op.cc
index c76620802e..3af00467c1 100644
--- a/tensorflow_addons/custom_ops/activations/cc/ops/gelu_op.cc
+++ b/tensorflow_addons/custom_ops/activations/cc/ops/gelu_op.cc
@@ -35,5 +35,5 @@ REGISTER_OP("Addons>GeluGrad")
     .Attr("approximate: bool = true")
     .SetShapeFn(shape_inference::MergeBothInputsShapeFn);
 
-}  // end namespace addons
-}  // namespace tensorflow
\ No newline at end of file
+}  // namespace addons
+}  // namespace tensorflow
diff --git a/tensorflow_addons/custom_ops/activations/cc/ops/hardshrink_op.cc b/tensorflow_addons/custom_ops/activations/cc/ops/hardshrink_op.cc
index 5f40127801..4dd9b58e0f 100644
--- a/tensorflow_addons/custom_ops/activations/cc/ops/hardshrink_op.cc
+++ b/tensorflow_addons/custom_ops/activations/cc/ops/hardshrink_op.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/core/framework/shape_inference.h" namespace tensorflow { +namespace addons { REGISTER_OP("Addons>Hardshrink") .Input("features: T") @@ -36,4 +37,5 @@ REGISTER_OP("Addons>HardshrinkGrad") .Attr("upper: float = 1.0") .SetShapeFn(shape_inference::MergeBothInputsShapeFn); +} // namespace addons } // namespace tensorflow From 18f89e232cca6aff64032a5a60593ca9ad042cb8 Mon Sep 17 00:00:00 2001 From: Tzu-Wei Sung Date: Sun, 29 Sep 2019 13:10:53 -0700 Subject: [PATCH 2/7] test general properties for activations --- tensorflow_addons/activations/BUILD | 13 +++++ .../activations/activations_test.py | 48 +++++++++++++++++++ tensorflow_addons/activations/gelu_test.py | 13 ----- .../activations/hardshrink_test.py | 13 ----- .../activations/sparsemax_test.py | 14 ------ .../activations/tanhshrink_test.py | 13 ----- 6 files changed, 61 insertions(+), 53 deletions(-) create mode 100644 tensorflow_addons/activations/activations_test.py diff --git a/tensorflow_addons/activations/BUILD b/tensorflow_addons/activations/BUILD index e5f9640bb4..663248fa5c 100644 --- a/tensorflow_addons/activations/BUILD +++ b/tensorflow_addons/activations/BUILD @@ -18,6 +18,19 @@ py_library( srcs_version = "PY2AND3", ) +py_test( + name = "activations_test", + size = "small", + srcs = [ + "activations_test.py", + ], + main = "activations_test.py", + srcs_version = "PY2AND3", + deps = [ + ":activations", + ], +) + py_test( name = "sparsemax_test", size = "medium", diff --git a/tensorflow_addons/activations/activations_test.py b/tensorflow_addons/activations/activations_test.py new file mode 100644 index 0000000000..115aaebc33 --- /dev/null +++ b/tensorflow_addons/activations/activations_test.py @@ -0,0 +1,48 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ==============================================================================
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+import tensorflow as tf
+from tensorflow_addons import activations
+from tensorflow_addons.utils import test_utils
+
+
+@test_utils.run_all_in_graph_and_eager_modes
+class ActivationsTest(tf.test.TestCase):
+
+    ALL_ACTIVATIONS = ["sparsemax", "gelu", "hardshrink", "tanhshrink"]
+
+    def test_serialization(self):
+        for name in ALL_ACTIVATIONS:
+            fn = tf.keras.activations.get(name)
+            ref_fn = getattr(activations, name)
+            self.assertEqual(fn, ref_fn)
+            config = keras.activations.serialize(fn)
+            fn = keras.activations.deserialize(config)
+            self.assertEqual(fn, ref_fn)
+
+    def test_serialization_with_layers(self):
+        for name in ALL_ACTIVATIONS:
+            layer = tf.keras.layers.Dense(
+                3, activation=getattr(activations, name))
+            config = tf.keras.layers.serialize(layer)
+            deserialized_layer = tf.keras.layers.deserialize(config)
+            self.assertEqual(deserialized_layer.__class__.__name__,
+                             layer.__class__.__name__)
+            self.assertEqual(deserialized_layer.activation.__name__, name)
diff --git a/tensorflow_addons/activations/gelu_test.py b/tensorflow_addons/activations/gelu_test.py
index 33254e75c2..79d168f565 100644
--- a/tensorflow_addons/activations/gelu_test.py
+++ b/tensorflow_addons/activations/gelu_test.py
@@ -64,19 +64,6 @@ def test_unknown_shape(self):
         x = tf.ones(shape=shape, dtype=tf.float32)
         self.assertAllClose(fn(x), gelu(x))
 
-    def test_serialization(self):
-        config = tf.keras.activations.serialize(gelu)
-        fn = tf.keras.activations.deserialize(config)
-        self.assertEqual(fn, gelu)
-
-    def test_serialization_with_layers(self):
-        layer = tf.keras.layers.Dense(3, activation=gelu)
-        config = tf.keras.layers.serialize(layer)
-        deserialized_layer = tf.keras.layers.deserialize(config)
-        self.assertEqual(deserialized_layer.__class__.__name__,
-                         layer.__class__.__name__)
-        self.assertEqual(deserialized_layer.activation.__name__, "gelu")
-
 
 if __name__ == "__main__":
     tf.test.main()
diff --git a/tensorflow_addons/activations/hardshrink_test.py b/tensorflow_addons/activations/hardshrink_test.py
index 98a9af871f..87c4e84aa2 100644
--- a/tensorflow_addons/activations/hardshrink_test.py
+++ b/tensorflow_addons/activations/hardshrink_test.py
@@ -66,19 +66,6 @@ def test_unknown_shape(self):
         x = tf.ones(shape=shape, dtype=tf.float32)
         self.assertAllClose(fn(x), hardshrink(x))
 
-    def test_serialization(self):
-        config = tf.keras.activations.serialize(hardshrink)
-        fn = tf.keras.activations.deserialize(config)
-        self.assertEqual(fn, hardshrink)
-
-    def test_serialization_with_layers(self):
-        layer = tf.keras.layers.Dense(3, activation=hardshrink)
-        config = tf.keras.layers.serialize(layer)
-        deserialized_layer = tf.keras.layers.deserialize(config)
-        self.assertEqual(deserialized_layer.__class__.__name__,
-                         layer.__class__.__name__)
-        self.assertEqual(deserialized_layer.activation.__name__, "hardshrink")
-
 
 if __name__ == "__main__":
     tf.test.main()
diff --git a/tensorflow_addons/activations/sparsemax_test.py b/tensorflow_addons/activations/sparsemax_test.py
index 62e03f0184..3abe28fbb1 100644
--- a/tensorflow_addons/activations/sparsemax_test.py
+++ b/tensorflow_addons/activations/sparsemax_test.py
@@ -274,20 +274,6 @@ def test_gradient_against_estimate(self, dtype=None):
             lambda logits: sparsemax(logits), [z], delta=1e-6)
         self.assertAllCloseAccordingToType(jacob_sym, jacob_num)
 
-    def test_serialization(self, dtype=None):
-        ref_fn = sparsemax
-        config = tf.keras.activations.serialize(ref_fn)
-        fn = tf.keras.activations.deserialize(config)
-        self.assertEqual(fn, ref_fn)
-
-    def test_serialization_with_layers(self, dtype=None):
-        layer = tf.keras.layers.Dense(3, activation=sparsemax)
-        config = tf.keras.layers.serialize(layer)
-        deserialized_layer = tf.keras.layers.deserialize(config)
-        self.assertEqual(deserialized_layer.__class__.__name__,
-                         layer.__class__.__name__)
-        self.assertEqual(deserialized_layer.activation.__name__, "sparsemax")
-
 
 if __name__ == '__main__':
     tf.test.main()
diff --git a/tensorflow_addons/activations/tanhshrink_test.py b/tensorflow_addons/activations/tanhshrink_test.py
index 59d3e43e3f..139b099be1 100644
--- a/tensorflow_addons/activations/tanhshrink_test.py
+++ b/tensorflow_addons/activations/tanhshrink_test.py
@@ -47,19 +47,6 @@ def test_theoretical_gradients(self, dtype):
         theoretical, numerical = tf.test.compute_gradient(tanhshrink, [x])
         self.assertAllCloseAccordingToType(theoretical, numerical, atol=1e-4)
 
-    def test_serialization(self):
-        config = tf.keras.activations.serialize(tanhshrink)
-        fn = tf.keras.activations.deserialize(config)
-        self.assertEqual(fn, tanhshrink)
-
-    def test_serialization_with_layers(self):
-        layer = tf.keras.layers.Dense(3, activation=tanhshrink)
-        config = tf.keras.layers.serialize(layer)
-        deserialized_layer = tf.keras.layers.deserialize(config)
-        self.assertEqual(deserialized_layer.__class__.__name__,
-                         layer.__class__.__name__)
-        self.assertEqual(deserialized_layer.activation.__name__, "tanhshrink")
-
 
 if __name__ == "__main__":
     tf.test.main()

From fad1dd4434e9874811a7a6465a50890f281eb1c7 Mon Sep 17 00:00:00 2001
From: Tzu-Wei Sung
Date: Mon, 30 Sep 2019 09:21:18 -0700
Subject: [PATCH 3/7] add lisht

---
 tensorflow_addons/activations/activations_test.py |  3 ++-
 tensorflow_addons/activations/lisht_test.py       | 13 -------------
 2 files changed, 2 insertions(+), 14 deletions(-)

diff --git a/tensorflow_addons/activations/activations_test.py b/tensorflow_addons/activations/activations_test.py
index 115aaebc33..ff85a70606 100644
--- a/tensorflow_addons/activations/activations_test.py
+++ b/tensorflow_addons/activations/activations_test.py
@@ -26,7 +26,8 @@
 @test_utils.run_all_in_graph_and_eager_modes
 class ActivationsTest(tf.test.TestCase):
 
-    ALL_ACTIVATIONS = ["sparsemax", "gelu", "hardshrink", "tanhshrink"]
+    ALL_ACTIVATIONS = ["gelu", "hardshrink", "lisht", "sparsemax",
+                       "tanhshrink"]
 
     def test_serialization(self):
         for name in ALL_ACTIVATIONS:
diff --git a/tensorflow_addons/activations/lisht_test.py b/tensorflow_addons/activations/lisht_test.py
index b4e7fd2dfc..c14417ef41 100644
--- a/tensorflow_addons/activations/lisht_test.py
+++ b/tensorflow_addons/activations/lisht_test.py
@@ -55,19 +55,6 @@ def test_unknown_shape(self):
         x = tf.ones(shape=shape, dtype=tf.float32)
         self.assertAllClose(fn(x), lisht(x))
 
-    def test_serialization(self):
-        config = tf.keras.activations.serialize(lisht)
-        fn = tf.keras.activations.deserialize(config)
-        self.assertEqual(fn, lisht)
-
-    def test_serialization_with_layers(self):
-        layer = tf.keras.layers.Dense(3, activation=lisht)
-        config = tf.keras.layers.serialize(layer)
-        deserialized_layer = tf.keras.layers.deserialize(config)
-        self.assertEqual(deserialized_layer.__class__.__name__,
-                         layer.__class__.__name__)
-        self.assertEqual(deserialized_layer.activation.__name__, "lisht")
-
 
 if __name__ == "__main__":
     tf.test.main()

From ea37e7c06de0bf4dfab109af01c079da2abf840f Mon Sep 17 00:00:00 2001
From: Tzu-Wei Sung
Date: Mon, 30 Sep 2019 09:24:03 -0700
Subject: [PATCH 4/7] remove unused import

---
 tensorflow_addons/activations/gelu_test.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/tensorflow_addons/activations/gelu_test.py b/tensorflow_addons/activations/gelu_test.py
index 79d168f565..2aac0ea281 100644
--- a/tensorflow_addons/activations/gelu_test.py
+++ b/tensorflow_addons/activations/gelu_test.py
@@ -19,8 +19,6 @@
 
 from absl.testing import parameterized
 
-import math
-
 import numpy as np
 import tensorflow as tf
 from tensorflow_addons.activations import gelu

From dbdc28d0a9b8f632f85353ddc3476d5b12cc3b82 Mon Sep 17 00:00:00 2001
From: Tzu-Wei Sung
Date: Mon, 30 Sep 2019 09:24:26 -0700
Subject: [PATCH 5/7] sanity check

---
 tensorflow_addons/activations/activations_test.py | 14 +++++++-------
 1 file changed, 7 insertions(+), 7 deletions(-)

diff --git a/tensorflow_addons/activations/activations_test.py b/tensorflow_addons/activations/activations_test.py
index ff85a70606..31a4b82196 100644
--- a/tensorflow_addons/activations/activations_test.py
+++ b/tensorflow_addons/activations/activations_test.py
@@ -17,7 +17,6 @@
 from __future__ import division
 from __future__ import print_function
 
-
 import tensorflow as tf
 from tensorflow_addons import activations
 from tensorflow_addons.utils import test_utils
@@ -26,20 +25,21 @@
 @test_utils.run_all_in_graph_and_eager_modes
 class ActivationsTest(tf.test.TestCase):
 
-    ALL_ACTIVATIONS = ["gelu", "hardshrink", "lisht", "sparsemax",
-                       "tanhshrink"]
+    ALL_ACTIVATIONS = [
+        "gelu", "hardshrink", "lisht", "sparsemax", "tanhshrink"
+    ]
 
     def test_serialization(self):
-        for name in ALL_ACTIVATIONS:
+        for name in self.ALL_ACTIVATIONS:
             fn = tf.keras.activations.get(name)
             ref_fn = getattr(activations, name)
             self.assertEqual(fn, ref_fn)
-            config = keras.activations.serialize(fn)
-            fn = keras.activations.deserialize(config)
+            config = tf.keras.activations.serialize(fn)
+            fn = tf.keras.activations.deserialize(config)
             self.assertEqual(fn, ref_fn)
 
     def test_serialization_with_layers(self):
-        for name in ALL_ACTIVATIONS:
+        for name in self.ALL_ACTIVATIONS:
             layer = tf.keras.layers.Dense(
                 3, activation=getattr(activations, name))
             config = tf.keras.layers.serialize(layer)

From 7ac2d818b79b65b936d18422c608479e19f9abd5 Mon Sep 17 00:00:00 2001
From: Tzu-Wei Sung
Date: Wed, 2 Oct 2019 20:32:29 -0700
Subject: [PATCH 6/7] add requirements of serialization tests

---
 tensorflow_addons/activations/README.md | 1 +
 1 file changed, 1 insertion(+)

diff --git a/tensorflow_addons/activations/README.md b/tensorflow_addons/activations/README.md
index ede5eb30fb..ca2a4947ee 100644
--- a/tensorflow_addons/activations/README.md
+++ b/tensorflow_addons/activations/README.md
@@ -35,6 +35,7 @@ must:
    or `run_all_in_graph_and_eager_modes` (for TestCase subclass)
    decorator.
  * Add a `py_test` to this sub-package's BUILD file.
+ * Add activation name to [activations_test.py](https://github.com/tensorflow/addons/tree/master/tensorflow_addons/activations/activations_test.py) to test searialization.
 
 #### Documentation Requirements
  * Update the table of contents in this sub-package's README.
From 3635401b851b5ec327dad1ab711737082c56b53e Mon Sep 17 00:00:00 2001
From: Tzu-Wei Sung
Date: Wed, 2 Oct 2019 21:56:13 -0700
Subject: [PATCH 7/7] fix typo

---
 tensorflow_addons/activations/README.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tensorflow_addons/activations/README.md b/tensorflow_addons/activations/README.md
index ca2a4947ee..c548a32923 100644
--- a/tensorflow_addons/activations/README.md
+++ b/tensorflow_addons/activations/README.md
@@ -35,7 +35,7 @@ must:
    or `run_all_in_graph_and_eager_modes` (for TestCase subclass)
    decorator.
  * Add a `py_test` to this sub-package's BUILD file.
- * Add activation name to [activations_test.py](https://github.com/tensorflow/addons/tree/master/tensorflow_addons/activations/activations_test.py) to test searialization.
+ * Add activation name to [activations_test.py](https://github.com/tensorflow/addons/tree/master/tensorflow_addons/activations/activations_test.py) to test serialization.
 
 #### Documentation Requirements
  * Update the table of contents in this sub-package's README.
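
A note on the hard-coded expectations introduced in patch 1: they follow directly from the definitions of the activations, so they can be re-derived offline. Below is a minimal sketch of that derivation, assuming NumPy and SciPy are available (`scipy.special.erf` stands in for `tf.math.erf`; the helper names are illustrative and none of this code appears in the patches):

```python
import numpy as np
from scipy.special import erf  # illustration only; the tests themselves use TF ops


def gelu_exact(x):
    # gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2)))
    return 0.5 * x * (1.0 + erf(x / np.sqrt(2.0)))


def gelu_approximate(x):
    # tanh approximation: 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3)))
    return 0.5 * x * (
        1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))


def hardshrink(x, lower=-1.0, upper=1.0):
    # Identity outside [lower, upper], zero inside.
    return np.where((x < lower) | (x > upper), x, 0.0)


def tanhshrink(x):
    return x - np.tanh(x)


x = np.array([-2.0, -1.0, 0.0, 1.0, 2.0])
print(gelu_approximate(x))  # ~ [-0.04540229, -0.158808,   0., 0.841192,   1.9545977]
print(gelu_exact(x))        # ~ [-0.04550028, -0.15865526, 0., 0.8413447,  1.9544997]
print(hardshrink(x))        #   [-2., 0., 0., 0., 2.]
print(tanhshrink(x))        # ~ [-1.0359724,  -0.23840582, 0., 0.23840582, 1.0359724]
```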
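The sparsemax docstring formula, `sparsemax[i, j] = max(logits[i, j] - tau(logits[i, :]), 0)`, can be checked the same way. The following sketch implements the 1-D simplex projection that the formula describes, following the paper cited as [1] in the docstring; it is an illustration of the math, not the `_compute_2d_sparsemax` kernel:

```python
import numpy as np


def sparsemax_1d(z):
    # Sort descending, find the support size k(z), then the threshold tau
    # that makes the surviving coordinates sum to one.
    z_sorted = np.sort(z)[::-1]
    k = np.arange(1, len(z) + 1)
    cumsum = np.cumsum(z_sorted)
    support = 1 + k * z_sorted > cumsum  # coordinates kept in the output
    k_z = k[support][-1]
    tau = (cumsum[support][-1] - 1) / k_z
    return np.maximum(z - tau, 0.0)


print(sparsemax_1d(np.array([1.0, 2.0, 3.0])))  # -> [0.  0.  1. ]
print(sparsemax_1d(np.array([0.5, 1.0, 1.0])))  # -> [0.  0.5 0.5]
```

Because `tau` is chosen so the kept coordinates sum to exactly 1, sparsemax, unlike softmax, returns exact zeros for low-scoring classes.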