[WIP] Add Adadelta optimizer
Missing:
- testing
- sparse support
- GPU support
Mistobaan committed Jan 7, 2016
1 parent 5d824f1 commit 7a262ee
Showing 8 changed files with 362 additions and 1 deletion.
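
For readers following the kernel math in the diff below, here is a minimal NumPy sketch of the Adadelta update rule from http://arxiv.org/abs/1212.5701 (the function and variable names are illustrative only, not part of the commit):

import numpy as np

def adadelta_step(var, accum, update_accum, grad, decay_rate, epsilon):
    # Decaying average of squared gradients.
    accum = decay_rate * accum + (1.0 - decay_rate) * np.square(grad)
    # Scale the gradient by the ratio of the two running RMS values.
    update = np.sqrt(update_accum + epsilon) / np.sqrt(accum + epsilon) * grad
    # Decaying average of squared updates.
    update_accum = decay_rate * update_accum + (1.0 - decay_rate) * np.square(update)
    # Apply the update.
    var = var - update
    return var, accum, update_accum
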
142 changes: 142 additions & 0 deletions tensorflow/core/kernels/training_ops.cc
@@ -41,6 +41,28 @@ struct ApplyGradientDescent<CPUDevice, T> {
}
};

template <typename T>
struct ApplyAdadelta<CPUDevice, T> {
void operator()(const CPUDevice& d, typename TTypes<T>::Flat var,
typename TTypes<T>::Flat accum,
typename TTypes<T>::Flat update_accum,
typename TTypes<T>::ConstScalar decay_rate,
typename TTypes<T>::ConstScalar epsilon,
typename TTypes<T>::ConstFlat grad) {
if (DoInline(var.size())) {
accum = accum * decay_rate() + grad.square() * (1 - decay_rate());
auto update = (update_accum + epsilon()).sqrt() * (accum + epsilon()).rsqrt() * grad;
update_accum = update_accum * decay_rate() + update.square() * (1 - decay_rate());
var -= update;
} else {
accum.device(d) = accum * decay_rate() + grad.square() * (1 - decay_rate());
auto update = (update_accum + epsilon()).sqrt() * (accum + epsilon()).rsqrt() * grad;
update_accum.device(d) = update_accum * decay_rate() + update.square() * (1 - decay_rate());
var.device(d) -= update;
}
}
};

template <typename T>
struct ApplyAdagrad<CPUDevice, T> {
void operator()(const CPUDevice& d, typename TTypes<T>::Flat var,
@@ -202,6 +224,126 @@ REGISTER_KERNELS(GPU, double);
#endif
#undef REGISTER_KERNELS

template <typename Device, typename T>
class ApplyAdadeltaOp : public OpKernel {
public:
explicit ApplyAdadeltaOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_));
}

void Compute(OpKernelContext* ctx) override {
if (use_exclusive_lock_) {
mutex_lock l1(*ctx->input_ref_mutex(0));
// Don't try to acquire a lock on the second ref as they share the same
// mutex.
//
// mutex_lock l2(*ctx->input_ref_mutex(1));
DoValidate(ctx);
if (!ctx->status().ok()) return;
DoCompute(ctx);
} else {
DoValidate(ctx);
if (!ctx->status().ok()) return;
DoCompute(ctx);
}
ctx->forward_ref_input_to_ref_output(0, 0);
}

private:
bool use_exclusive_lock_;

void DoValidate(OpKernelContext* ctx) {
Tensor var = ctx->mutable_input(0, use_exclusive_lock_);
Tensor accum = ctx->mutable_input(1, use_exclusive_lock_);
Tensor update_accum = ctx->mutable_input(2, use_exclusive_lock_);

OP_REQUIRES(
ctx, var.IsInitialized(),
errors::FailedPrecondition(
"Attempting to use uninitialized variables: ", def().input(0)));
OP_REQUIRES(
ctx, accum.IsInitialized(),
errors::FailedPrecondition(
"Attempting to use uninitialized variables: ", def().input(1)));
OP_REQUIRES(
ctx, update_accum.IsInitialized(),
errors::FailedPrecondition(
"Attempting to use uninitialized variables: ", def().input(2)));

const Tensor& decay_rate = ctx->input(3);
const Tensor& epsilon = ctx->input(4);
const Tensor& grad = ctx->input(5);

OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(decay_rate.shape()),
errors::InvalidArgument("decay_rate is not a scalar: ",
decay_rate.shape().DebugString()));

OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(epsilon.shape()),
errors::InvalidArgument("epsilon is not a scalar: ",
epsilon.shape().DebugString()));

OP_REQUIRES(
ctx, var.shape().IsSameSize(accum.shape()),
errors::InvalidArgument("var and accum do not have the same shape",
var.shape().DebugString(), " ",
accum.shape().DebugString()));
OP_REQUIRES(
ctx, var.shape().IsSameSize(grad.shape()),
errors::InvalidArgument("var and grad do not have the same shape",
var.shape().DebugString(), " ",
grad.shape().DebugString()));
}

void DoCompute(OpKernelContext* ctx) {
const Device& device = ctx->template eigen_device<Device>();
Tensor var = ctx->mutable_input(0, use_exclusive_lock_);
Tensor accum = ctx->mutable_input(1, use_exclusive_lock_);
Tensor update_accum = ctx->mutable_input(2, use_exclusive_lock_);

const Tensor& decay_rate = ctx->input(3);
const Tensor& epsilon = ctx->input(4);
const Tensor& grad = ctx->input(5);

functor::ApplyAdadelta<Device, T>()(device, var.flat<T>(), accum.flat<T>(),
update_accum.flat<T>(), decay_rate.scalar<T>(),
epsilon.scalar<T>(), grad.flat<T>());
}
};

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

#define REGISTER_KERNELS(D, T) \
REGISTER_KERNEL_BUILDER( \
Name("ApplyAdadelta").Device(DEVICE_##D).TypeConstraint<T>("T"), \
ApplyAdadeltaOp<D##Device, T>);

REGISTER_KERNELS(CPU, float);
REGISTER_KERNELS(CPU, double);

#if GOOGLE_CUDA
// Forward declarations of the functor specializations for GPU.
namespace functor {
#define DECLARE_GPU_SPEC(T) \
template <> \
void ApplyAdadelta<GPUDevice, T>::operator()( \
const GPUDevice& d, typename TTypes<T>::Flat var, \
typename TTypes<T>::Flat accum, \
typename TTypes<T>::Flat update_accum, \
typename TTypes<T>::ConstScalar decay_rate, \
typename TTypes<T>::ConstScalar epsilon, \
typename TTypes<T>::ConstFlat grad); \
extern template struct ApplyAdadelta<GPUDevice, T>;
DECLARE_GPU_SPEC(float);
DECLARE_GPU_SPEC(double);
#undef DECLARE_GPU_SPEC
} // namespace functor

REGISTER_KERNELS(GPU, float);
REGISTER_KERNELS(GPU, double);
#endif
#undef REGISTER_KERNELS

template <typename Device, typename T>
class ApplyAdagradOp : public OpKernel {
public:
10 changes: 10 additions & 0 deletions tensorflow/core/kernels/training_ops.h
@@ -33,6 +33,16 @@ struct ApplyGradientDescent {
typename TTypes<T>::ConstFlat delta);
};

template <typename Device, typename T>
struct ApplyAdadelta {
void operator()(const Device& d, typename TTypes<T>::Flat var,
typename TTypes<T>::Flat accum,
typename TTypes<T>::Flat update_accum,
typename TTypes<T>::ConstScalar decay_rate,
typename TTypes<T>::ConstScalar epsilon,
typename TTypes<T>::ConstFlat grad);
};

template <typename Device, typename T>
struct ApplyAdagrad {
void operator()(const Device& d, typename TTypes<T>::Flat var,
29 changes: 29 additions & 0 deletions tensorflow/core/ops/training_ops.cc
@@ -35,6 +35,35 @@ use_locking: If True, the subtraction will be protected by a lock;
otherwise the behavior is undefined, but may exhibit less contention.
)doc");

REGISTER_OP("ApplyAdadelta")
.Input("var: Ref(T)")
.Input("accum: Ref(T)")
.Input("update_accum: Ref(T)")
.Input("decay_rate: T")
.Input("epsilon: T")
.Input("grad: T")
.Output("out: Ref(T)")
.Attr("T: numbertype")
.Attr("use_locking: bool = false")
.Doc(R"doc(
Update '*var' according to the adadelta scheme.
accum = decay_rate * accum + (1 - decay_rate) * grad.square();
update = (update_accum + epsilon).sqrt() * (accum + epsilon).rsqrt() * grad;
update_accum = decay_rate * update_accum + (1 - decay_rate) * update.square();
var -= update;
var: Should be from a Variable().
accum: Should be from a Variable().
update_accum: Should be from a Variable().
decay_rate: Decay rate. Must be a scalar.
epsilon: Constant factor. Must be a scalar.
grad: The gradient.
out: Same as "var".
use_locking: If True, updating of the var, accum and update_accum tensors will be protected by
a lock; otherwise the behavior is undefined, but may exhibit less contention.
)doc");

REGISTER_OP("ApplyAdagrad")
.Input("var: Ref(T)")
.Input("accum: Ref(T)")
84 changes: 84 additions & 0 deletions tensorflow/python/training/adadelta.py
@@ -0,0 +1,84 @@
# Copyright 2015 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Adadelta for TensorFlow."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import ops
from tensorflow.python.ops import constant_op
from tensorflow.python.training import optimizer
from tensorflow.python.training import training_ops


class AdadeltaOptimizer(optimizer.Optimizer):
"""Optimizer that implements the Adadelta algorithm.
@@__init__
"""

def __init__(self, decay_rate=0.001, epsilon=1e-8,
use_locking=False, name="Adadelta"):
"""Construct a new Adadelta optimizer.
Implementation is based on http://arxiv.org/abs/1212.5701
Args:
decay_rate: A `Tensor` or a floating point value. The decay rate.
epsilon: A `Tensor` or a floating point value. A constant epsilon used
to better condition the grad update.
use_locking: If `True` use locks for update operations.
name: Optional name prefix for the operations created when applying
gradients. Defaults to "Adadelta".
"""
super(AdadeltaOptimizer, self).__init__(use_locking, name)
self._decay_rate = decay_rate
self._epsilon = epsilon

# Tensor versions of the constructor arguments, created in _prepare().
self._decay_rate_t = None
self._epsilon_t = None

def _create_slots(self, var_list):
for v in var_list:
self._zeros_slot(v, "decay_rate", self._name)
self._zeros_slot(v, "epsilon", self._name)
self._zeros_slot(v, "accum", self._name)
self._zeros_slot(v, "update_accum", self._name)

def _prepare(self):
self._decay_rate_t = ops.convert_to_tensor(self._decay_rate,
name="decay_rate")
self._epsilon_t = ops.convert_to_tensor(self._epsilon,
name="epsilon")

def _apply_dense(self, grad, var):
decay_rate = self.get_slot(var, "decay_rate")
epsilon = self.get_slot(var, "epsilon")
accum = self.get_slot(var, "accum")
update_accum = self.get_slot(var, "update_accum")

return training_ops.apply_adadelta(
var, accum, update_accum,
self._decay_rate_t, self._epsilon_t, grad,
use_locking=self._use_locking).op

# def _apply_sparse(self, grad, var):
# mom = self.get_slot(var, "adadelta")
# return training_ops.sparse_apply_adadelta(
# var, mom,
# self._learning_rate_tensor, grad.values, grad.indices,
# self._adadelta_tensor, use_locking=self._use_locking).op
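
A minimal usage sketch, assuming the class is exported as tf.train.AdadeltaOptimizer (the toy loss and variable below are illustrative only, not part of the commit):

import tensorflow as tf

w = tf.Variable([1.0, 2.0])
loss = tf.reduce_sum(tf.square(w))

# decay_rate and epsilon default to 0.001 and 1e-8, per the constructor above.
opt = tf.train.AdadeltaOptimizer()
train_step = opt.minimize(loss)

with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())
    for _ in range(10):
        sess.run(train_step)
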
83 changes: 83 additions & 0 deletions tensorflow/python/training/adadelta_test.py
@@ -0,0 +1,83 @@
# Copyright 2015 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Tests for Momentum."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow.python.platform

import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf


class AdadeltaOptimizerTest(tf.test.TestCase):

def testBasic(self):
with self.test_session():
var0 = tf.Variable([1.0, 2.0])
var1 = tf.Variable([3.0, 4.0])
grads0 = tf.constant([0.1, 0.1])
grads1 = tf.constant([0.01, 0.01])
adadelta_opt = tf.train.AdadeltaOptimizer()
adadelta_update = adadelta_opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
tf.initialize_all_variables().run()

# Check we have slots
self.assertEqual(["accum", "decay_rate", "epsilon", "update_accum"], adadelta_opt.get_slot_names())
slot0 = adadelta_opt.get_slot(var0, "accum")
self.assertEquals(slot0.get_shape(), var0.get_shape())
self.assertFalse(slot0 in tf.trainable_variables())
slot1 = adadelta_opt.get_slot(var1, "accum")
self.assertEquals(slot1.get_shape(), var1.get_shape())
self.assertFalse(slot1 in tf.trainable_variables())

# Fetch params to validate initial values
self.assertAllClose([1.0, 2.0], var0.eval())
self.assertAllClose([3.0, 4.0], var1.eval())

adadelta_update.run()

# Check that the accumulators have been updated.
self.assertAllClose(np.array([0.1, 0.1]), slot0.eval())
self.assertAllClose(np.array([0.01, 0.01]), slot1.eval())

# Check that the parameters have been updated.
self.assertAllClose(np.array([1.0 - (0.1 * 2.0),
2.0 - (0.1 * 2.0)]),
var0.eval())
self.assertAllClose(np.array([3.0 - (0.01 * 2.0),
4.0 - (0.01 * 2.0)]),
var1.eval())
# Step 2: the momentum accumulators contain the previous update.
adadelta_update.run()
# Check that the momentum accumulators have been updated.
self.assertAllClose(np.array([(0.9 * 0.1 + 0.1), (0.9 * 0.1 + 0.1)]),
slot0.eval())
self.assertAllClose(np.array([(0.9 * 0.01 + 0.01), (0.9 * 0.01 + 0.01)]),
slot1.eval())
# Check that the parameters have been updated.
self.assertAllClose(
np.array([1.0 - (0.1 * 2.0) - ((0.9 * 0.1 + 0.1) * 2.0),
2.0 - (0.1 * 2.0) - ((0.9 * 0.1 + 0.1) * 2.0)]),
var0.eval())
self.assertAllClose(np.array([2.98 - ((0.9 * 0.01 + 0.01) * 2.0),
3.98 - ((0.9 * 0.01 + 0.01) * 2.0)]),
var1.eval())

if __name__ == "__main__":
tf.test.main()
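
Since the commit message notes that testing is still missing, expected slot and variable values for an Adadelta-specific test could be derived from a short NumPy reference loop such as this sketch (names are illustrative; the defaults follow the constructor above):

import numpy as np

def adadelta_reference(var, grad, steps, decay_rate=0.001, epsilon=1e-8):
    # Mirrors the ApplyAdadelta kernel for a constant gradient over `steps` updates.
    accum = np.zeros_like(var)
    update_accum = np.zeros_like(var)
    for _ in range(steps):
        accum = decay_rate * accum + (1.0 - decay_rate) * np.square(grad)
        update = np.sqrt(update_accum + epsilon) / np.sqrt(accum + epsilon) * grad
        update_accum = decay_rate * update_accum + (1.0 - decay_rate) * np.square(update)
        var = var - update
    return var, accum, update_accum

# For example, expected values for var0 in testBasic after one step:
# expected_var0, expected_accum0, _ = adadelta_reference(
#     np.array([1.0, 2.0]), np.array([0.1, 0.1]), steps=1)
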
2 changes: 1 addition & 1 deletion tensorflow/python/training/momentum_test.py
@@ -165,7 +165,7 @@ def _dbParamsMom01(self):
"""Return dist-belief momentum values.
Return values been generated from the dist-belief momentum unittest,
- running with a learning rate of 0.1 and a momemntum of 0.1.
+ running with a learning rate of 0.1 and a momentum of 0.1.
These values record how a parameter vector of size 10, initialized with 0.0,
gets updated with 10 consecutive momentum steps. It uses random gradients.
