Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Adds a collection of keras layers used for the gan architectures.
PiperOrigin-RevId: 385889780
- Loading branch information
Showing
4 changed files
with
851 additions
and
0 deletions.
There are no files selected for viewing
97 changes: 97 additions & 0 deletions
97
tensorflow_graphics/projects/gan/exponential_moving_average.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
# Copyright 2020 The TensorFlow Authors | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
"""Implements an ExponentialMovingAverage class that is checkpointable.""" | ||
|
||
from typing import Sequence | ||
|
||
import tensorflow as tf | ||
|
||
|
||
class ExponentialMovingAverage(tf.Module): | ||
"""Exponential moving average. | ||
This class is a checkpointable implementation of a subset of the functionality | ||
provided by tf.train.ExponentialMovingAverage. The tf version is not | ||
checkpointable due to use of tf.Variable.ref() to associate tf.Variables | ||
objects to their corresponding averages | ||
(cf. https://github.com/tensorflow/tensorflow/issues/38452). This version uses | ||
the order of the tf.Variable objects in a sequence to associate the variables | ||
with their averages. | ||
Note: This class offers less functionality than the tensorflow version and it | ||
is only implemented for replica context. | ||
Attributes: | ||
averaged_variables: A sequence of tf.Variables that stores the averages for | ||
the variables. They are associated to the new values that are provided to | ||
ExponentialMovingAverage.apply() by the order in the sequence. If None a | ||
call to ExponentialMovingAverage.apply() initializes the variable before | ||
applying the update. | ||
""" | ||
|
||
def __init__(self, decay: float = 0.999): | ||
"""Initializes exponential moving average. | ||
Args: | ||
decay: The decay rate of the exponential moving average. | ||
""" | ||
self.averaged_variables: Sequence[tf.Variable] = None | ||
self._decay = decay | ||
|
||
def _ema_assign_fn(self, variable: tf.Variable, value: tf.Tensor): | ||
"""Updates the exponential moving average for a single variable.""" | ||
return variable.assign(self._decay * variable + (1.0 - self._decay) * value) | ||
|
||
def _apply_values(self, variables: Sequence[tf.Variable]): | ||
"""Applies the new values to the exponential moving averages.""" | ||
|
||
def merge_fn(strategy: tf.distribute.Strategy, variable: tf.Variable, | ||
value: tf.Tensor): | ||
value = strategy.extended.reduce_to(tf.distribute.ReduceOp.MEAN, value, | ||
variable) | ||
strategy.extended.update(variable, self._ema_assign_fn, args=(value,)) | ||
|
||
replica_context = tf.distribute.get_replica_context() | ||
|
||
if replica_context: | ||
for variable_ema, variable in zip(self.averaged_variables, variables): | ||
replica_context.merge_call(merge_fn, args=(variable_ema, variable)) | ||
else: | ||
raise NotImplementedError( | ||
'Cross-replica context version not implemented.') | ||
|
||
def apply(self, variables: Sequence[tf.Variable]): | ||
"""Applies new values to the averages. | ||
This function is called to update the averages with new values. If the | ||
variables for the averages have not been created before this function | ||
creates new variables for the averages before the update. | ||
Args: | ||
variables: The variables storing the values to apply to the averages. The | ||
sequence is assumed to have the same order of the variables as the | ||
averages stored in self.averaged_variables. If self.averaged_variables | ||
is None it gets initialized with a new sequence of variables with the | ||
values of the provided variables as initial value. | ||
""" | ||
if self.averaged_variables is None: | ||
with tf.init_scope(): | ||
strategy = tf.distribute.get_strategy() | ||
self.averaged_variables = [] | ||
|
||
for variable in variables: | ||
with strategy.extended.colocate_vars_with(variable): | ||
self.averaged_variables.append( | ||
tf.Variable(initial_value=variable.read_value())) | ||
self._apply_values(variables) |
69 changes: 69 additions & 0 deletions
69
tensorflow_graphics/projects/gan/exponential_moving_average_test.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
# Copyright 2020 The TensorFlow Authors | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
"""Tests for gan.exponential_moving_average.""" | ||
|
||
import tensorflow as tf | ||
|
||
from tensorflow_graphics.projects.gan import exponential_moving_average | ||
|
||
|
||
class ExponentialMovingAverageTest(tf.test.TestCase): | ||
|
||
def test_decay_one_values_are_from_initialization(self): | ||
ema = exponential_moving_average.ExponentialMovingAverage(decay=1.0) | ||
initial_value = 2.0 | ||
variable = tf.Variable(initial_value) | ||
|
||
ema.apply((variable,)) | ||
variable.assign(3.0) | ||
ema.apply((variable,)) | ||
|
||
self.assertAllClose(ema.averaged_variables[0], initial_value) | ||
|
||
def test_decay_zero_returns_last_value(self): | ||
ema = exponential_moving_average.ExponentialMovingAverage(decay=0.0) | ||
final_value = 3.0 | ||
variable = tf.Variable(2.0) | ||
|
||
ema.apply((variable,)) | ||
variable.assign(final_value) | ||
ema.apply((variable,)) | ||
|
||
self.assertAllClose(ema.averaged_variables[0], final_value) | ||
|
||
def test_cross_replica_context_raises_error(self): | ||
ema = exponential_moving_average.ExponentialMovingAverage(decay=0.0) | ||
|
||
with self.assertRaisesRegex( | ||
NotImplementedError, 'Cross-replica context version not implemented.'): | ||
with tf.distribute.MirroredStrategy().scope(): | ||
variable = tf.Variable(2.0) | ||
ema.apply((variable,)) | ||
|
||
def test_mirrored_strategy_replica_context_runs(self): | ||
ema = exponential_moving_average.ExponentialMovingAverage(decay=0.5) | ||
strategy = tf.distribute.MirroredStrategy() | ||
|
||
def apply_to_ema(variable): | ||
ema.apply((variable,)) | ||
|
||
with strategy.scope(): | ||
variable = tf.Variable(2.0) | ||
strategy.run(apply_to_ema, (variable,)) | ||
|
||
self.assertAllClose(ema.averaged_variables[0], variable.read_value()) | ||
|
||
|
||
if __name__ == '__main__': | ||
tf.test.main() |
Oops, something went wrong.