diff --git a/tensorflow/contrib/opt/BUILD b/tensorflow/contrib/opt/BUILD
index 0446e823d95f8e..4d34038f7fa867 100644
--- a/tensorflow/contrib/opt/BUILD
+++ b/tensorflow/contrib/opt/BUILD
@@ -25,6 +25,7 @@ py_library(
         "python/training/lars_optimizer.py",
         "python/training/lazy_adam_gs_optimizer.py",
         "python/training/lazy_adam_optimizer.py",
+        "python/training/sparse_adam_optimizer.py",
         "python/training/matrix_functions.py",
         "python/training/model_average_optimizer.py",
        "python/training/moving_average_optimizer.py",
diff --git a/tensorflow/contrib/opt/__init__.py b/tensorflow/contrib/opt/__init__.py
index e8fc52342ceabb..c1886056f10ac6 100644
--- a/tensorflow/contrib/opt/__init__.py
+++ b/tensorflow/contrib/opt/__init__.py
@@ -29,6 +29,7 @@
 from tensorflow.contrib.opt.python.training.lars_optimizer import *
 from tensorflow.contrib.opt.python.training.ggt import *
 from tensorflow.contrib.opt.python.training.lazy_adam_optimizer import *
+from tensorflow.contrib.opt.python.training.sparse_adam_optimizer import *
 from tensorflow.contrib.opt.python.training.lazy_adam_gs_optimizer import *
 from tensorflow.contrib.opt.python.training.model_average_optimizer import *
 from tensorflow.contrib.opt.python.training.moving_average_optimizer import *
@@ -55,6 +56,7 @@
     'LARSOptimizer',
     'LazyAdamGSOptimizer',
     'LazyAdamOptimizer',
+    'SparseAdamOptimizer',
     'NadamOptimizer',
     'MovingAverageOptimizer',
     'MomentumWOptimizer',
diff --git a/tensorflow/contrib/opt/python/training/sparse_adam_optimizer.py b/tensorflow/contrib/opt/python/training/sparse_adam_optimizer.py
new file mode 100644
index 00000000000000..1fd097952824e4
--- /dev/null
+++ b/tensorflow/contrib/opt/python/training/sparse_adam_optimizer.py
@@ -0,0 +1,236 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Variant of the Adam optimizer that handles sparse updates more efficiently
+and whose convergence is close to that of Adam.
+
+Compared with the original Adam optimizer, the optimizer in this file can
+provide a large improvement in model training throughput for some
+applications. The semantics of SparseAdam are close to those of the original
+Adam algorithm, and so is its convergence.
+
+A detailed description of SparseAdam:
+- maintain a timestamp variable (pre_step) for each weight
+- count the skipped steps for each weight
+- skipped_step = global_step - pre_step
+- adapt the Adam learning rate based on skipped_step
+- for each time step, SparseAdam updates the weight first, then updates the
+  first moment (m) and the second moment (v)
+
+lr = learning_rate * sqrt(1 - beta2 ** pre_step) / (1 - beta1 ** pre_step) *
+     (1 - beta1 ** skipped_step) / (1 - beta1)
+variable = variable - lr * m / (sqrt(v) + epsilon)
+m = m * beta1 ** skipped_step + (1 - beta1) * gradient
+v = v * beta2 ** skipped_step + (1 - beta2) * gradient ** 2
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.eager import context
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.ops import state_ops
+from tensorflow.python.training import adam
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import constant_op
+
+
+class SparseAdamOptimizer(adam.AdamOptimizer):
+  """Variant of the Adam optimizer that handles sparse updates more efficiently.
+
+  The original Adam algorithm maintains two moving-average accumulators for
+  each trainable variable; the accumulators are updated at every step.
+  Like LazyAdam, this class provides lazier handling of gradient updates for
+  sparse variables: it only updates the moving-average accumulators for sparse
+  variable indices that appear in the current batch, rather than updating the
+  accumulators for all indices. Compared with the original Adam optimizer, it
+  can provide large improvements in model training throughput for some
+  applications. Compared with LazyAdam, the semantics of SparseAdam are closer
+  to those of the original Adam optimizer, and so is its convergence.
+  """
+
+  def _create_slots(self, var_list):
+    # Create the beta1 and beta2 accumulators on the same device as the first
+    # variable. Sort the var_list to make sure this device is consistent across
+    # workers (these need to go on the same PS, otherwise some updates are
+    # silently ignored).
+    first_var = min(var_list, key=lambda x: x.name)
+    self._create_non_slot_variable(
+        initial_value=self._beta1, name="beta1_power", colocate_with=first_var)
+    self._create_non_slot_variable(
+        initial_value=self._beta2, name="beta2_power", colocate_with=first_var)
+    self._create_non_slot_variable(
+        initial_value=1.0, name="global_step", colocate_with=first_var)
+
+    # Create slots for the first and second moments.
+    for v in var_list:
+      self._zeros_slot(v, "m", self._name)
+      self._zeros_slot(v, "v", self._name)
+      self._get_or_make_slot(
+          v,
+          constant_op.constant(
+              1.0, dtype=v.dtype.base_dtype, shape=v.get_shape()), "pre_step",
+          self._name)
+
+  def _get_step_accumulators(self):
+    with ops.init_scope():
+      if context.executing_eagerly():
+        graph = None
+      else:
+        graph = ops.get_default_graph()
+      return self._get_non_slot_variable("global_step", graph=graph)
+
+  def _finish(self, update_ops, name_scope):
+    # Update the power accumulators and the global step counter.
+    with ops.control_dependencies(update_ops):
+      beta1_power, beta2_power = self._get_beta_accumulators()
+      global_step = self._get_step_accumulators()
+      with ops.colocate_with(beta1_power):
+        update_beta1 = beta1_power.assign(
+            beta1_power * self._beta1_t, use_locking=self._use_locking)
+        update_beta2 = beta2_power.assign(
+            beta2_power * self._beta2_t, use_locking=self._use_locking)
+        update_step = global_step.assign(
+            global_step + 1, use_locking=self._use_locking)
+    return control_flow_ops.group(
+        *update_ops + [update_beta1, update_beta2, update_step],
+        name=name_scope)
+
+  def _apply_sparse(self, grad, var):
+    beta1_power, beta2_power = self._get_beta_accumulators()
+    beta1_power = math_ops.cast(beta1_power, var.dtype.base_dtype)
+    beta2_power = math_ops.cast(beta2_power, var.dtype.base_dtype)
+    lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype)
+    beta1_t = math_ops.cast(self._beta1_t, var.dtype.base_dtype)
+    beta2_t = math_ops.cast(self._beta2_t, var.dtype.base_dtype)
+    epsilon_t = math_ops.cast(self._epsilon_t, var.dtype.base_dtype)
+    global_step = self._get_step_accumulators()
+    global_step = math_ops.cast(global_step, var.dtype.base_dtype)
+    pre_step = self.get_slot(var, "pre_step")
+
+    indices = grad.indices
+    pre_step_slice = array_ops.gather(pre_step, indices)
+    skipped_steps = global_step - pre_step_slice
+
+    m = self.get_slot(var, "m")
+    m_slice = array_ops.gather(m, indices)
+    v = self.get_slot(var, "v")
+    v_slice = array_ops.gather(v, indices)
+    # For performance reasons, math_ops.exp(a * math_ops.log(b)) is used here
+    # instead of math_ops.pow(b, a).
+    # \\(lr := learning_rate * sqrt(1 - beta2 ** pre_step) /
+    #          (1 - beta1 ** pre_step) * (1 - beta1 ** skipped_step) /
+    #          (1 - beta1)\\)
+    lr = ((lr_t * math_ops.sqrt(
+        1 - math_ops.exp(pre_step_slice * math_ops.log(beta2_t))) /
+           (1 - math_ops.exp(pre_step_slice * math_ops.log(beta1_t)))) *
+          (1 - math_ops.exp(math_ops.log(beta1_t) * skipped_steps)) /
+          (1 - beta1_t))
+    # \\(variable -= lr * m / (sqrt(v) + epsilon)\\)
+    var_slice = lr * m_slice / (math_ops.sqrt(v_slice) + epsilon_t)
+    var_update_op = state_ops.scatter_sub(
+        var, indices, var_slice, use_locking=self._use_locking)
+
+    with ops.control_dependencies([var_update_op]):
+      # \\(m := m * beta1 ** skipped_step + (1 - beta1) * g_t\\)
+      # For performance reasons, math_ops.exp(a * math_ops.log(b)) is used here
+      # instead of math_ops.pow(b, a).
+      m_t_slice = (
+          math_ops.exp(math_ops.log(beta1_t) * skipped_steps) * m_slice +
+          (1 - beta1_t) * grad.values)
+      m_update_op = state_ops.scatter_update(
+          m, indices, m_t_slice, use_locking=self._use_locking)
+
+      # \\(v := v * beta2 ** skipped_step + (1 - beta2) * (g_t * g_t)\\)
+      # For performance reasons, math_ops.exp(a * math_ops.log(b)) is used here
+      # instead of math_ops.pow(b, a).
+      v_t_slice = (
+          math_ops.exp(math_ops.log(beta2_t) * skipped_steps) * v_slice +
+          (1 - beta2_t) * math_ops.square(grad.values))
+      v_update_op = state_ops.scatter_update(
+          v, indices, v_t_slice, use_locking=self._use_locking)
+
+    with ops.control_dependencies([m_update_op, v_update_op]):
+      pre_step_update_op = state_ops.scatter_update(
+          pre_step, indices, global_step, use_locking=self._use_locking)
+
+    return control_flow_ops.group(var_update_op, m_update_op, v_update_op,
+                                  pre_step_update_op)
+
+  def _resource_apply_sparse(self, grad, var, indices):
+    beta1_power, beta2_power = self._get_beta_accumulators()
+    beta1_power = math_ops.cast(beta1_power, var.dtype.base_dtype)
+    beta2_power = math_ops.cast(beta2_power, var.dtype.base_dtype)
+    lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype)
+    beta1_t = math_ops.cast(self._beta1_t, var.dtype.base_dtype)
+    beta2_t = math_ops.cast(self._beta2_t, var.dtype.base_dtype)
+    epsilon_t = math_ops.cast(self._epsilon_t, var.dtype.base_dtype)
+    global_step = self._get_step_accumulators()
+    global_step = math_ops.cast(global_step, var.dtype.base_dtype)
+    pre_step = self.get_slot(var, "pre_step")
+
+    pre_step_slice = array_ops.gather(pre_step, indices)
+    skipped_steps = global_step - pre_step_slice
+
+    m = self.get_slot(var, "m")
+    m_slice = array_ops.gather(m, indices)
+    v = self.get_slot(var, "v")
+    v_slice = array_ops.gather(v, indices)
+
+    # For performance reasons, math_ops.exp(a * math_ops.log(b)) is used here
+    # instead of math_ops.pow(b, a).
+    # \\(lr := learning_rate * sqrt(1 - beta2 ** pre_step) /
+    #          (1 - beta1 ** pre_step) * (1 - beta1 ** skipped_step) /
+    #          (1 - beta1)\\)
+    lr = ((lr_t * math_ops.sqrt(
+        1 - math_ops.exp(pre_step_slice * math_ops.log(beta2_t))) /
+           (1 - math_ops.exp(pre_step_slice * math_ops.log(beta1_t)))) *
+          (1 - math_ops.exp(math_ops.log(beta1_t) * skipped_steps)) /
+          (1 - beta1_t))
+    # \\(variable -= lr * m / (sqrt(v) + epsilon)\\)
+    var_slice = lr * m_slice / (math_ops.sqrt(v_slice) + epsilon_t)
+    var_update_op = resource_variable_ops.resource_scatter_sub(
+        var.handle, indices, var_slice)
+
+    with ops.control_dependencies([var_update_op]):
+      # \\(m := m * beta1 ** skipped_step + (1 - beta1) * g_t\\)
+      # For performance reasons, math_ops.exp(a * math_ops.log(b)) is used here
+      # instead of math_ops.pow(b, a).
+      m_t_slice = (
+          math_ops.exp(math_ops.log(beta1_t) * skipped_steps) * m_slice +
+          (1 - beta1_t) * grad)
+      m_update_op = resource_variable_ops.resource_scatter_update(
+          m.handle, indices, m_t_slice)
+
+      # \\(v := v * beta2 ** skipped_step + (1 - beta2) * (g_t * g_t)\\)
+      # For performance reasons, math_ops.exp(a * math_ops.log(b)) is used here
+      # instead of math_ops.pow(b, a).
+      v_t_slice = (
+          math_ops.exp(math_ops.log(beta2_t) * skipped_steps) * v_slice +
+          (1 - beta2_t) * math_ops.square(grad))
+      v_update_op = resource_variable_ops.resource_scatter_update(
+          v.handle, indices, v_t_slice)
+
+    with ops.control_dependencies([m_update_op, v_update_op]):
+      pre_step_update_op = resource_variable_ops.resource_scatter_update(
+          pre_step.handle, indices, global_step)
+
+    return control_flow_ops.group(var_update_op, m_update_op, v_update_op,
+                                  pre_step_update_op)
diff --git a/tensorflow/contrib/opt/python/training/sparse_adam_optimizer_test.py b/tensorflow/contrib/opt/python/training/sparse_adam_optimizer_test.py
new file mode 100644
index 00000000000000..b6bc3b526ce557
--- /dev/null
+++ b/tensorflow/contrib/opt/python/training/sparse_adam_optimizer_test.py
@@ -0,0 +1,391 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for SparseAdamOptimizer."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import parameterized
+import numpy as np
+
+from tensorflow.contrib.opt.python.training import sparse_adam_optimizer
+from tensorflow.python.eager import context
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+
+
+# This test file only covers the basic case (skipped step = 1) of SparseAdam.
+# When the skipped step is 1, the SparseAdam update formula is the same as
+# Adam's.
+def adam_update_numpy(param,
+                      g_t,
+                      t,
+                      m,
+                      v,
+                      alpha=0.001,
+                      beta1=0.9,
+                      beta2=0.999,
+                      epsilon=1e-8):
+  # SparseAdam applies the variable update with the old m and v, and only then
+  # updates the accumulators.
+  alpha_t = alpha * np.sqrt(1 - beta2**t) / (1 - beta1**t)
+  param_t = param - alpha_t * m / (np.sqrt(v) + epsilon)
+
+  m_t = beta1 * m + (1 - beta1) * g_t
+  v_t = beta2 * v + (1 - beta2) * g_t * g_t
+
+  return param_t, m_t, v_t
+
+
+class AdamOptimizerTest(test.TestCase, parameterized.TestCase):
+
+  @parameterized.parameters([False, True])
+  def testSparse(self, use_resource):
+    for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+      with self.cached_session():
+        # Initialize variables for numpy implementation.
+        m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
+        var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
+        grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
+        var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
+        grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype)
+
+        if use_resource:
+          var0 = resource_variable_ops.ResourceVariable(var0_np)
+          var1 = resource_variable_ops.ResourceVariable(var1_np)
+        else:
+          var0 = variables.Variable(var0_np)
+          var1 = variables.Variable(var1_np)
+
+        grads0_np_indices = np.array([0, 1], dtype=np.int32)
+        grads0 = ops.IndexedSlices(
+            constant_op.constant(grads0_np),
+            constant_op.constant(grads0_np_indices), constant_op.constant([2]))
+        grads1_np_indices = np.array([0, 1], dtype=np.int32)
+        grads1 = ops.IndexedSlices(
+            constant_op.constant(grads1_np),
+            constant_op.constant(grads1_np_indices), constant_op.constant([2]))
+        opt = sparse_adam_optimizer.SparseAdamOptimizer(epsilon=1e-7)
+        update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+        variables.global_variables_initializer().run()
+
+        # Fetch params to validate initial values
+        self.assertAllClose([1.0, 2.0], var0.eval(), rtol=1e-3, atol=1e-3)
+        self.assertAllClose([3.0, 4.0], var1.eval(), rtol=1e-3, atol=1e-3)
+
+        beta1_power, beta2_power = opt._get_beta_accumulators()
+
+        # Run 3 steps of Adam
+        for t in range(1, 4):
+          self.assertAllCloseAccordingToType(
+              0.9**t, beta1_power.eval(), rtol=1e-3, atol=1e-3)
+          self.assertAllCloseAccordingToType(
+              0.999**t, beta2_power.eval(), rtol=1e-3, atol=1e-3)
+          update.run()
+
+          var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0)
+          var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1)
+
+          # Validate updated params
+          self.assertAllCloseAccordingToType(
+              var0_np, var0.eval(), rtol=1e-3, atol=1e-3)
+          self.assertAllCloseAccordingToType(
+              var1_np, var1.eval(), rtol=1e-3, atol=1e-3)
+
+  @parameterized.parameters([False, True])
+  def testSparseDevicePlacement(self, use_resource):
+    for index_dtype in [dtypes.int32, dtypes.int64]:
+      with self.cached_session(force_gpu=test.is_gpu_available()):
+        # If a GPU is available, tests that all optimizer ops can be placed on
+        # it (i.e. they have GPU kernels).
+        if use_resource:
+          var = resource_variable_ops.ResourceVariable([[1.0], [2.0]])
+        else:
+          var = variables.Variable([[1.0], [2.0]])
+
+        indices = constant_op.constant([0, 1], dtype=index_dtype)
+        gathered_sum = math_ops.reduce_sum(array_ops.gather(var, indices))
+        optimizer = sparse_adam_optimizer.SparseAdamOptimizer(3.0)
+        minimize_op = optimizer.minimize(gathered_sum)
+        variables.global_variables_initializer().run()
+        minimize_op.run()
+
+  @parameterized.parameters([False, True])
+  def testSparseRepeatedIndices(self, use_resource):
+    for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+      with self.cached_session():
+        if use_resource:
+          repeated_index_update_var = resource_variable_ops.ResourceVariable(
+              [[1.0], [2.0]], dtype=dtype)
+          aggregated_update_var = resource_variable_ops.ResourceVariable(
+              [[1.0], [2.0]], dtype=dtype)
+        else:
+          repeated_index_update_var = variables.Variable([[1.0], [2.0]],
+                                                         dtype=dtype)
+          aggregated_update_var = variables.Variable([[1.0], [2.0]],
+                                                     dtype=dtype)
+
+        grad_repeated_index = ops.IndexedSlices(
+            constant_op.constant([0.1, 0.1], shape=[2, 1], dtype=dtype),
+            constant_op.constant([1, 1]), constant_op.constant([2, 1]))
+        grad_aggregated = ops.IndexedSlices(
+            constant_op.constant([0.2], shape=[1, 1], dtype=dtype),
+            constant_op.constant([1]), constant_op.constant([2, 1]))
+        repeated_update_opt = sparse_adam_optimizer.SparseAdamOptimizer()
+        repeated_update = repeated_update_opt.apply_gradients(
+            [(grad_repeated_index, repeated_index_update_var)])
+        aggregated_update_opt = sparse_adam_optimizer.SparseAdamOptimizer()
+        aggregated_update = aggregated_update_opt.apply_gradients(
+            [(grad_aggregated, aggregated_update_var)])
+        variables.global_variables_initializer().run()
+        self.assertAllClose(
+            aggregated_update_var.eval(),
+            repeated_index_update_var.eval(),
+            rtol=1e-3,
+            atol=1e-3)
+        for _ in range(3):
+          repeated_update.run()
+          aggregated_update.run()
+          self.assertAllClose(
+              aggregated_update_var.eval(),
+              repeated_index_update_var.eval(),
+              rtol=1e-3,
+              atol=1e-3)
+
+  def doTestBasic(self, use_resource=False, use_callable_params=False):
+    for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
+      with self.session(graph=ops.Graph()):
+        # Initialize variables for numpy implementation.
+        m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
+        var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
+        grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
+        var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
+        grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype)
+
+        if use_resource:
+          var0 = resource_variable_ops.ResourceVariable(
+              var0_np, name="var0_%d" % i)
+          var1 = resource_variable_ops.ResourceVariable(
+              var1_np, name="var1_%d" % i)
+        else:
+          var0 = variables.Variable(var0_np)
+          var1 = variables.Variable(var1_np)
+        grads0 = constant_op.constant(grads0_np)
+        grads1 = constant_op.constant(grads1_np)
+
+        learning_rate = lambda: 0.001
+        beta1 = lambda: 0.9
+        beta2 = lambda: 0.999
+        epsilon = lambda: 1e-8
+        if not use_callable_params:
+          learning_rate = learning_rate()
+          beta1 = beta1()
+          beta2 = beta2()
+          epsilon = epsilon()
+
+        opt = sparse_adam_optimizer.SparseAdamOptimizer(
+            learning_rate=learning_rate)
+        update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+        opt_variables = opt.variables()
+        beta1_power, beta2_power = opt._get_beta_accumulators()
+        self.assertIsNotNone(beta1_power)
+        self.assertIsNotNone(beta2_power)
+        self.assertIn(beta1_power, opt_variables)
+        self.assertIn(beta2_power, opt_variables)
+
+        if not context.executing_eagerly():
+          with ops.Graph().as_default():
+            # Shouldn't return non-slot variables from other graphs.
+            self.assertEqual(0, len(opt.variables()))
+          self.evaluate(variables.global_variables_initializer())
+          # Fetch params to validate initial values
+          self.assertAllClose([1.0, 2.0],
+                              self.evaluate(var0),
+                              rtol=1e-3,
+                              atol=1e-3)
+          self.assertAllClose([3.0, 4.0],
+                              self.evaluate(var1),
+                              rtol=1e-3,
+                              atol=1e-3)
+
+        beta1_power, beta2_power = opt._get_beta_accumulators()
+
+        # Run 3 steps of Adam
+        for t in range(1, 4):
+          if not context.executing_eagerly():
+            self.evaluate(update)
+          elif t > 1:
+            opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+
+          self.assertAllCloseAccordingToType(
+              0.9**(t + 1), self.evaluate(beta1_power), rtol=1e-3, atol=1e-3)
+          self.assertAllCloseAccordingToType(
+              0.999**(t + 1), self.evaluate(beta2_power), rtol=1e-3, atol=1e-3)
+
+          var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0)
+          var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1)
+
+          # Validate updated params
+          self.assertAllCloseAccordingToType(
+              var0_np, self.evaluate(var0), rtol=1e-3, atol=1e-3)
+          self.assertAllCloseAccordingToType(
+              var1_np, self.evaluate(var1), rtol=1e-3, atol=1e-3)
+          if use_resource:
+            self.assertEqual("var0_%d/Adam:0" % (i,),
+                             opt.get_slot(var=var0, name="m").name)
+
+  def testBasic(self):
+    with self.cached_session():
+      self.doTestBasic(use_resource=False)
+
+  @test_util.run_in_graph_and_eager_modes(reset_test=True)
+  def testResourceBasic(self):
+    self.doTestBasic(use_resource=True)
+
+  def testBasicCallableParams(self):
+    with context.eager_mode():
+      self.doTestBasic(use_resource=True, use_callable_params=True)
+
+  def testTensorLearningRate(self):
+    for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+      with self.cached_session():
+        # Initialize variables for numpy implementation.
+        m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
+        var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
+        grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
+        var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
+        grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype)
+
+        var0 = variables.Variable(var0_np)
+        var1 = variables.Variable(var1_np)
+        grads0 = constant_op.constant(grads0_np)
+        grads1 = constant_op.constant(grads1_np)
+        opt = sparse_adam_optimizer.SparseAdamOptimizer(
+            constant_op.constant(0.001))
+        update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+        variables.global_variables_initializer().run()
+
+        # Fetch params to validate initial values
+        self.assertAllClose([1.0, 2.0], var0.eval(), rtol=1e-3, atol=1e-3)
+        self.assertAllClose([3.0, 4.0], var1.eval(), rtol=1e-3, atol=1e-3)
+
+        beta1_power, beta2_power = opt._get_beta_accumulators()
+
+        # Run 3 steps of Adam
+        for t in range(1, 4):
+          self.assertAllCloseAccordingToType(
+              0.9**t, beta1_power.eval(), rtol=1e-3, atol=1e-3)
+          self.assertAllCloseAccordingToType(
+              0.999**t, beta2_power.eval(), rtol=1e-3, atol=1e-3)
+          update.run()
+
+          var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0)
+          var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1)
+
+          # Validate updated params
+          self.assertAllCloseAccordingToType(
+              var0_np, var0.eval(), rtol=1e-3, atol=1e-3)
+          self.assertAllCloseAccordingToType(
+              var1_np, var1.eval(), rtol=1e-3, atol=1e-3)
+
+  def testSharing(self):
+    for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+      with self.cached_session():
+        # Initialize variables for numpy implementation.
+        m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
+        var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
+        grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
+        var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
+        grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype)
+
+        var0 = variables.Variable(var0_np)
+        var1 = variables.Variable(var1_np)
+        grads0 = constant_op.constant(grads0_np)
+        grads1 = constant_op.constant(grads1_np)
+        opt = sparse_adam_optimizer.SparseAdamOptimizer()
+        update1 = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+        update2 = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+        variables.global_variables_initializer().run()
+
+        beta1_power, beta2_power = opt._get_beta_accumulators()
+
+        # Fetch params to validate initial values
+        self.assertAllClose([1.0, 2.0], var0.eval(), rtol=1e-3, atol=1e-3)
+        self.assertAllClose([3.0, 4.0], var1.eval(), rtol=1e-3, atol=1e-3)
+
+        # Run 3 steps of intertwined Adam1 and Adam2.
+        for t in range(1, 4):
+          self.assertAllCloseAccordingToType(
+              0.9**t, beta1_power.eval(), rtol=1e-3, atol=1e-3)
+          self.assertAllCloseAccordingToType(
+              0.999**t, beta2_power.eval(), rtol=1e-3, atol=1e-3)
+          if t % 2 == 0:
+            update1.run()
+          else:
+            update2.run()
+
+          var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0)
+          var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1)
+
+          # Validate updated params
+          self.assertAllCloseAccordingToType(
+              var0_np, var0.eval(), rtol=1e-3, atol=1e-3)
+          self.assertAllCloseAccordingToType(
+              var1_np, var1.eval(), rtol=1e-3, atol=1e-3)
+
+  def testTwoSessions(self):
+    optimizer = sparse_adam_optimizer.SparseAdamOptimizer()
+
+    with context.eager_mode():
+      var0 = variables.Variable(np.array([1.0, 2.0]), name="v0")
+      grads0 = constant_op.constant(np.array([0.1, 0.1]))
+      optimizer.apply_gradients([(grads0, var0)])
+
+    g = ops.Graph()
+    with g.as_default():
+      with self.session(graph=g):
+        var0 = variables.Variable(np.array([1.0, 2.0]), name="v0")
+        grads0 = constant_op.constant(np.array([0.1, 0.1]))
+        optimizer.apply_gradients([(grads0, var0)])
+
+    gg = ops.Graph()
+    with gg.as_default():
+      with self.session(graph=gg):
+        var0 = variables.Variable(np.array([1.0, 2.0]), name="v0")
+        grads0 = constant_op.constant(np.array([0.1, 0.1]))
+
+        # If the optimizer saves any state not keyed by graph the following
+        # line fails.
+        optimizer.apply_gradients([(grads0, var0)])
+
+  def testSlotsUniqueEager(self):
+    with context.eager_mode():
+      v1 = resource_variable_ops.ResourceVariable(1.)
+      v2 = resource_variable_ops.ResourceVariable(1.)
+      opt = sparse_adam_optimizer.SparseAdamOptimizer(1.)
+      opt.minimize(lambda: v1 + v2)
+      # There should be three non-slot variables (beta1_power, beta2_power and
+      # global_step), plus three unique slot variables (m, v and pre_step) for
+      # each of v1 and v2.
+      self.assertEqual(9, len(set(opt.variables())))
+
+
+if __name__ == "__main__":
+  test.main()
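
The docstring of sparse_adam_optimizer.py above describes the per-weight update in prose. As a quick sanity check of those formulas, the update for a single row whose gradient appears in the current batch can be written in plain NumPy as below. This sketch is not part of the patch; the helper name and the scalar hyperparameter defaults are illustrative only.

# Standalone NumPy sketch of the per-row SparseAdam update described in the
# sparse_adam_optimizer.py docstring. Illustrative only; not part of the patch.
import numpy as np


def sparse_adam_row_update(param, grad, m, v, pre_step, global_step,
                           lr=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8):
  """Updates one row whose gradient appears in the current batch."""
  skipped_step = global_step - pre_step  # Steps since this row was last seen.
  # Bias-corrected learning rate at pre_step, scaled by the geometric series
  # (1 - beta1**skipped_step) / (1 - beta1) accumulated over the skipped steps.
  alpha = (lr * np.sqrt(1.0 - beta2**pre_step) / (1.0 - beta1**pre_step) *
           (1.0 - beta1**skipped_step) / (1.0 - beta1))
  # The variable is updated with the old moments first; the moments then decay
  # by beta**skipped_step before absorbing the new gradient.
  param = param - alpha * m / (np.sqrt(v) + epsilon)
  m = m * beta1**skipped_step + (1.0 - beta1) * grad
  v = v * beta2**skipped_step + (1.0 - beta2) * grad * grad
  return param, m, v, global_step  # global_step becomes the row's new pre_step.

With skipped_step == 1 this reduces to the update order that adam_update_numpy in the test file checks against, which is why the test only needs to cover that case.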
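For context on how the new optimizer would be exercised end to end, a minimal graph-mode usage sketch follows. It assumes the tf.contrib.opt export added in __init__.py above; the variable names, shapes, and loss are made up for illustration.

# Illustrative TF 1.x usage: an embedding lookup yields IndexedSlices
# gradients, which is the sparse path SparseAdamOptimizer is built for.
import tensorflow as tf

embeddings = tf.get_variable("embeddings", shape=[10000, 64])
ids = tf.placeholder(tf.int32, shape=[None])
emb = tf.nn.embedding_lookup(embeddings, ids)  # sparse gradient w.r.t. embeddings
loss = tf.reduce_sum(tf.square(emb))

opt = tf.contrib.opt.SparseAdamOptimizer(learning_rate=0.001)
train_op = opt.minimize(loss)

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  sess.run(train_op, feed_dict={ids: [1, 5, 42]})

Only the rows gathered in each batch have their m, v and pre_step slots touched, which is where the throughput gain over dense Adam comes from.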