From 314a5fe6bef630b9df502eb7c79470c275b478fc Mon Sep 17 00:00:00 2001
From: rushabh-v
Date: Sun, 30 Aug 2020 00:47:25 +0530
Subject: [PATCH 1/2] fix typing of num_updates in moving_average

---
 tensorflow_addons/optimizers/moving_average.py | 12 ++++++++++--
 .../optimizers/tests/moving_average_test.py    | 13 +++++++++++++
 2 files changed, 23 insertions(+), 2 deletions(-)

diff --git a/tensorflow_addons/optimizers/moving_average.py b/tensorflow_addons/optimizers/moving_average.py
index db14979c12..76f21b4091 100644
--- a/tensorflow_addons/optimizers/moving_average.py
+++ b/tensorflow_addons/optimizers/moving_average.py
@@ -18,7 +18,7 @@
 from tensorflow_addons.optimizers import AveragedOptimizerWrapper
 from tensorflow_addons.utils import types
 
-from typing import Optional
+from typing import Union
 from typeguard import typechecked
 
 
@@ -47,7 +47,7 @@ def __init__(
         optimizer: types.Optimizer,
         sequential_update: bool = True,
         average_decay: types.FloatTensorLike = 0.99,
-        num_updates: Optional[str] = None,
+        num_updates: Union[None, int, tf.Variable] = None,
         start_step: int = 0,
         dynamic_decay: bool = False,
         name: str = "MovingAverage",
@@ -82,6 +82,14 @@ def __init__(
         super().__init__(optimizer, sequential_update, name, **kwargs)
         self._num_updates = num_updates
         if self._num_updates is not None:
+            if isinstance(self._num_updates, tf.Variable):
+                tf.debugging.assert_integer(
+                    self._num_updates,
+                    (
+                        'type of argument "num_updates" must be '
+                        + f"int; got {self._num_updates.dtype} instead"
+                    ),
+                )
             num_updates = tf.cast(self._num_updates, tf.float32, name="num_updates")
             average_decay = tf.minimum(
                 average_decay, (1.0 + num_updates) / (10.0 + num_updates)
diff --git a/tensorflow_addons/optimizers/tests/moving_average_test.py b/tensorflow_addons/optimizers/tests/moving_average_test.py
index 4e4b90983f..dfd541a3cc 100644
--- a/tensorflow_addons/optimizers/tests/moving_average_test.py
+++ b/tensorflow_addons/optimizers/tests/moving_average_test.py
@@ -68,6 +68,19 @@ def test_opt_failure():
         MovingAverage(base_opt, 0.5)
 
 
+@pytest.mark.usefixtures("maybe_run_functions_eagerly")
+def test_num_updates_valid():
+    for num_updates in [1, tf.Variable(1)]:
+        MovingAverage("sgd", num_updates=num_updates)
+
+
+@pytest.mark.usefixtures("maybe_run_functions_eagerly")
+def test_num_updates_invalid():
+    for num_updates in [1.0, tf.Variable(1.0), "a"]:
+        with pytest.raises(TypeError):
+            MovingAverage("sgd", num_updates=num_updates)
+
+
 @pytest.mark.usefixtures("maybe_run_functions_eagerly")
 def test_model_weights_update():
     grad = tf.Variable([[0.1]])

From d2f59772ae8c0ab1daac23412a017c5fc6619e35 Mon Sep 17 00:00:00 2001
From: rushabh-v
Date: Sun, 30 Aug 2020 09:14:37 +0530
Subject: [PATCH 2/2] use .format instead of fstring

---
 tensorflow_addons/optimizers/moving_average.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tensorflow_addons/optimizers/moving_average.py b/tensorflow_addons/optimizers/moving_average.py
index 76f21b4091..e40a9cafad 100644
--- a/tensorflow_addons/optimizers/moving_average.py
+++ b/tensorflow_addons/optimizers/moving_average.py
@@ -87,7 +87,7 @@ def __init__(
                     self._num_updates,
                     (
                         'type of argument "num_updates" must be '
-                        + f"int; got {self._num_updates.dtype} instead"
+                        "int; got {} instead".format(self._num_updates.dtype)
                     ),
                 )
             num_updates = tf.cast(self._num_updates, tf.float32, name="num_updates")