Adds Adagrad and Adadelta optimizers and associated tests. (keras-team#72)

* Add golden correctness tests for Adam and SGD

* Fix dtype issues

* Sync with main (keras-team#56)

* Minor touch ups

* Fix a pretty major bug

* Format code

* Big rethink of Variable API

* Make build-by-run the default build(), leveraging new zero_history KerasTensor mode

* Minor fixes

* Format code

* Switch back to build-by-eager-run for simplicity

* Add raise upon build failure

* Work around JAX bug.

* Add a few more tests.

* Add saving tests

* Adds test suite for SGD and golden correctness tests for all optimizers (keras-team#40)

* Add golden correctness tests for Adam and SGD

* Fix dtype issues

* Add binary accuracy (keras-team#41)

* chore: adding binary accuracy

* chore: fix docstring

* Add tests for add_loss and activity regularization.

* Reformat code

* Add ActivityRegularization layer

* Fix JAX CI.

* Add Lambda Callback (keras-team#42)

* Add LambdaCallback

* Add Lambda Callback

* Add Lambda Callback

* Rename lambda_callback_test.py

* Add einsum (keras-team#43)

* Add einsum

* address comments

* Fix format line length (keras-team#45)

* Add Embedding layer

* Shorten lines

* Add .vscode to .gitignore (keras-team#46)

* rm vscode settings

* add .vscode to gitignore

* Set demo program backend (keras-team#48)

* Add tests for training arg resolution in Layer.

* Implement mixed precision.

* Replace backend.execute with backend.numpy.XXX (keras-team#50)

* Add cosine similarity loss and update l2_normalize from regularizers (keras-team#34)

* Begin cosine loss

* Add testing for cosine similarity

* Fix formatting

* Docstring standardization

* Formatting

* Create numerical_utils

* Fix issue with call context lingering.

* Add the EarlyStopping callback (keras-team#44)

* add earlystopping callback

* addressing comments

* address comments

* addressing comments

* remove unused imports

* re-enable imports checks (keras-team#51)

* Add nn.one_hot (keras-team#52)

* Add GaussianDropout layer.

* Add GaussianNoise layer

* Add Categorical Accuracy Metric (keras-team#47)

* chore: adding categorical accuracy metric

* chore: reformat docstrings

* chore: reformat

* chore: ndims with len

* refactor the docstring

* Fix typos

* Implement masking.

---------

Co-authored-by: Francois Chollet <francois.chollet@gmail.com>
Co-authored-by: Aritra Roy Gosthipaty <aritra.born2fly@gmail.com>
Co-authored-by: Ramesh Sampath <1437573+sampathweb@users.noreply.github.com>
Co-authored-by: Chen Qian <chenmoney@google.com>
Co-authored-by: Haifeng Jin <5476582+haifeng-jin@users.noreply.github.com>
Co-authored-by: Gabriel Rasskin <43894452+grasskin@users.noreply.github.com>

* Adds rmsprop optimizer and tests

* Add AdamW optimizer and tests, minor formatting changes

* Implemented formatting fixes

* Adds clip norm and clip value tests to Adam

* Adds Adagrad and Adadelta optimizers

* Applies fixes to formatting and deletes unnecessary kwargs

---------

Co-authored-by: Francois Chollet <francois.chollet@gmail.com>
Co-authored-by: Aritra Roy Gosthipaty <aritra.born2fly@gmail.com>
Co-authored-by: Ramesh Sampath <1437573+sampathweb@users.noreply.github.com>
Co-authored-by: Chen Qian <chenmoney@google.com>
Co-authored-by: Haifeng Jin <5476582+haifeng-jin@users.noreply.github.com>
Co-authored-by: Gabriel Rasskin <43894452+grasskin@users.noreply.github.com>
7 people committed May 3, 2023
1 parent b805a5d commit 4212fdd
Showing 9 changed files with 395 additions and 22 deletions.
121 changes: 121 additions & 0 deletions keras_core/optimizers/adadelta.py
@@ -0,0 +1,121 @@
from keras_core import operations as ops
from keras_core.api_export import keras_core_export
from keras_core.optimizers import optimizer


@keras_core_export(["keras_core.optimizers.Adadelta"])
class Adadelta(optimizer.Optimizer):
    """Optimizer that implements the Adadelta algorithm.

    Adadelta optimization is a stochastic gradient descent method that is
    based on adaptive learning rate per dimension to address two drawbacks:

    - The continual decay of learning rates throughout training.
    - The need for a manually selected global learning rate.

    Adadelta is a more robust extension of Adagrad that adapts learning rates
    based on a moving window of gradient updates, instead of accumulating all
    past gradients. This way, Adadelta continues learning even when many
    updates have been done. Compared to Adagrad, in the original version of
    Adadelta you don't have to set an initial learning rate. In this version,
    the initial learning rate can be set, as in most other Keras optimizers.

    Args:
        learning_rate: Initial value for the learning rate: a floating point
            value. Defaults to 0.001. Note that `Adadelta` tends to benefit
            from higher initial learning rate values compared to other
            optimizers. To match the exact form in the original paper,
            use 1.0.
        rho: A floating point value. The decay rate. Defaults to 0.95.
        epsilon: Small floating point value used to maintain numerical
            stability. Defaults to 1e-7.
        {{base_optimizer_keyword_args}}

    Reference:
        - [Zeiler, 2012](http://arxiv.org/abs/1212.5701)
    """

    def __init__(
        self,
        learning_rate=0.001,
        rho=0.95,
        epsilon=1e-7,
        weight_decay=None,
        clipnorm=None,
        clipvalue=None,
        global_clipnorm=None,
        use_ema=False,
        ema_momentum=0.99,
        ema_overwrite_frequency=None,
        name="adadelta",
    ):
        super().__init__(
            learning_rate=learning_rate,
            weight_decay=weight_decay,
            clipnorm=clipnorm,
            clipvalue=clipvalue,
            global_clipnorm=global_clipnorm,
            use_ema=use_ema,
            ema_momentum=ema_momentum,
            ema_overwrite_frequency=ema_overwrite_frequency,
            name=name,
        )
        self.rho = rho
        self.epsilon = epsilon

    def build(self, var_list):
        if self.built:
            return
        super().build(var_list)
        self._accumulated_grads = []
        self._accumulated_delta_vars = []
        for var in var_list:
            self._accumulated_grads.append(
                self.add_variable_from_reference(var, "accumulated_grad")
            )
            self._accumulated_delta_vars.append(
                self.add_variable_from_reference(var, "accumulated_delta_var")
            )

    def update_step(self, grad, variable, learning_rate):
        """Update step given gradient and the associated model variable."""
        lr = ops.cast(learning_rate, variable.dtype)
        grad = ops.cast(grad, variable.dtype)

        rho = self.rho
        accumulated_grad = self._accumulated_grads[
            self._get_variable_index(variable)
        ]
        accumulated_delta_var = self._accumulated_delta_vars[
            self._get_variable_index(variable)
        ]

        def rms(x):
            return ops.sqrt(x + self.epsilon)

        accumulated_grad.assign(
            rho * accumulated_grad + (1 - rho) * grad * grad
        )
        delta_var = -rms(accumulated_delta_var) * grad / rms(accumulated_grad)
        accumulated_delta_var.assign(
            rho * accumulated_delta_var + (1 - rho) * delta_var * delta_var
        )
        variable.assign(variable + lr * delta_var)

    def get_config(self):
        config = super().get_config()

        config.update(
            {
                "rho": self.rho,
                "epsilon": self.epsilon,
            }
        )
        return config


Adadelta.__doc__ = Adadelta.__doc__.replace(
    "{{base_optimizer_keyword_args}}", optimizer.base_optimizer_keyword_args
)
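
For readers skimming the diff, here is a minimal standalone NumPy sketch of the recurrence that `Adadelta.update_step` above implements (the helper `adadelta_step` is purely illustrative and not part of the keras_core API). Starting from zero accumulators, which is how `build` is assumed to initialize them via `add_variable_from_reference`, it reproduces the values checked in `test_single_step` in the test file below.

import numpy as np


def adadelta_step(var, grad, acc_grad, acc_delta, lr=0.001, rho=0.95, eps=1e-7):
    # Mirrors update_step above: RMS-scaled step with decaying accumulators
    # for squared gradients and squared parameter updates.
    def rms(x):
        return np.sqrt(x + eps)

    acc_grad = rho * acc_grad + (1 - rho) * grad * grad
    delta = -rms(acc_delta) * grad / rms(acc_grad)
    acc_delta = rho * acc_delta + (1 - rho) * delta * delta
    return var + lr * delta, acc_grad, acc_delta


# Matches test_single_step below (learning_rate=0.5, fresh accumulators).
var = np.array([1.0, 2.0, 3.0, 4.0])
grad = np.array([1.0, 6.0, 7.0, 2.0])
var, _, _ = adadelta_step(var, grad, np.zeros(4), np.zeros(4), lr=0.5)
print(var)  # approximately [0.9993, 1.9993, 2.9993, 3.9993]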
74 changes: 74 additions & 0 deletions keras_core/optimizers/adadelta_test.py
@@ -0,0 +1,74 @@
import numpy as np

from keras_core import backend
from keras_core import testing
from keras_core.optimizers.adadelta import Adadelta


class AdadeltaTest(testing.TestCase):
    def test_config(self):
        optimizer = Adadelta(
            learning_rate=0.5,
            rho=0.9,
            epsilon=1e-5,
        )
        self.run_class_serialization_test(optimizer)

    def test_single_step(self):
        optimizer = Adadelta(learning_rate=0.5)
        grads = np.array([1.0, 6.0, 7.0, 2.0])
        vars = backend.Variable([1.0, 2.0, 3.0, 4.0])
        optimizer.apply_gradients(zip([grads], [vars]))
        self.assertAllClose(
            vars, [0.9993, 1.9993, 2.9993, 3.9993], rtol=1e-4, atol=1e-4
        )

    def test_weight_decay(self):
        grads, var1, var2, var3 = (
            np.zeros(()),
            backend.Variable(2.0),
            backend.Variable(2.0, name="exclude"),
            backend.Variable(2.0),
        )
        optimizer_1 = Adadelta(learning_rate=1.0, weight_decay=0.004)
        optimizer_1.apply_gradients(zip([grads], [var1]))

        optimizer_2 = Adadelta(learning_rate=1.0, weight_decay=0.004)
        optimizer_2.exclude_from_weight_decay(var_names=["exclude"])
        optimizer_2.apply_gradients(zip([grads, grads], [var1, var2]))

        optimizer_3 = Adadelta(learning_rate=1.0, weight_decay=0.004)
        optimizer_3.exclude_from_weight_decay(var_list=[var3])
        optimizer_3.apply_gradients(zip([grads, grads], [var1, var3]))

        self.assertAlmostEqual(var1.numpy(), 1.9760959, decimal=6)
        self.assertAlmostEqual(var2.numpy(), 2.0, decimal=6)
        self.assertAlmostEqual(var3.numpy(), 2.0, decimal=6)

    def test_correctness_with_golden(self):
        optimizer = Adadelta(learning_rate=1.0, rho=0.8, epsilon=1e-6)

        x = backend.Variable(np.ones([10]))
        grads = np.arange(0.1, 1.1, 0.1)
        first_grads = np.full((10,), 0.01)

        golden = np.tile(
            [[0.9978], [0.9947], [0.9915], [0.9882], [0.9849]], (1, 10)
        )

        optimizer.apply_gradients(zip([first_grads], [x]))
        for i in range(5):
            self.assertAllClose(x, golden[i], rtol=5e-4, atol=5e-4)
            optimizer.apply_gradients(zip([grads], [x]))

    def test_clip_norm(self):
        optimizer = Adadelta(clipnorm=1)
        grad = [np.array([100.0, 100.0])]
        clipped_grad = optimizer._clip_gradients(grad)
        self.assertAllClose(clipped_grad[0], [2**0.5 / 2, 2**0.5 / 2])

    def test_clip_value(self):
        optimizer = Adadelta(clipvalue=1)
        grad = [np.array([100.0, 100.0])]
        clipped_grad = optimizer._clip_gradients(grad)
        self.assertAllClose(clipped_grad[0], [1.0, 1.0])
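
A note on the 1.9760959 expectation in `test_weight_decay` (the same value appears in the Adagrad suite below): `var1` is updated by all three optimizers, and with a zero gradient the Adadelta delta is zero, so the result is driven entirely by weight decay. Assuming the shared Optimizer base class applies decoupled weight decay of the form `var = var * (1 - learning_rate * weight_decay)` (an assumption about code not shown in this diff), three such applications give:

# Back-of-the-envelope check; assumes decoupled weight decay of the form
# var *= (1 - lr * weight_decay) in the base optimizer (not shown here).
lr, wd, steps = 1.0, 0.004, 3
print(2.0 * (1.0 - lr * wd) ** steps)  # approximately 1.9760959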
107 changes: 107 additions & 0 deletions keras_core/optimizers/adagrad.py
@@ -0,0 +1,107 @@
from keras_core import initializers
from keras_core import operations as ops
from keras_core.api_export import keras_core_export
from keras_core.optimizers import optimizer


@keras_core_export(["keras_core.optimizers.Adagrad"])
class Adagrad(optimizer.Optimizer):
    """Optimizer that implements the Adagrad algorithm.

    Adagrad is an optimizer with parameter-specific learning rates, which are
    adapted relative to how frequently a parameter gets updated during
    training. The more updates a parameter receives, the smaller the updates.

    Args:
        learning_rate: Initial value for the learning rate: a floating point
            value. Defaults to 0.001. Note that `Adagrad` tends to benefit
            from higher initial learning rate values compared to other
            optimizers. To match the exact form in the original paper,
            use 1.0.
        initial_accumulator_value: Floating point value. Starting value for
            the accumulators (per-parameter momentum values). Must be
            non-negative.
        epsilon: Small floating point value used to maintain numerical
            stability.
        {{base_optimizer_keyword_args}}

    Reference:
        - [Duchi et al., 2011](
            http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf).
    """

    def __init__(
        self,
        learning_rate=0.001,
        initial_accumulator_value=0.1,
        epsilon=1e-7,
        weight_decay=None,
        clipnorm=None,
        clipvalue=None,
        global_clipnorm=None,
        use_ema=False,
        ema_momentum=0.99,
        ema_overwrite_frequency=None,
        name="adagrad",
    ):
        super().__init__(
            learning_rate=learning_rate,
            weight_decay=weight_decay,
            clipnorm=clipnorm,
            clipvalue=clipvalue,
            global_clipnorm=global_clipnorm,
            use_ema=use_ema,
            ema_momentum=ema_momentum,
            ema_overwrite_frequency=ema_overwrite_frequency,
            name=name,
        )
        self.initial_accumulator_value = initial_accumulator_value
        self.epsilon = epsilon

    def build(self, var_list):
        if self.built:
            return
        super().build(var_list)
        self._accumulators = []
        initializer = initializers.Constant(self.initial_accumulator_value)
        for var in var_list:
            self._accumulators.append(
                self.add_variable(
                    shape=var.shape,
                    initializer=initializer,
                    dtype=var.dtype,
                    name="accumulator",
                )
            )

    def update_step(self, gradient, variable, learning_rate):
        """Update step given gradient and the associated model variable."""
        lr = ops.cast(learning_rate, variable.dtype)
        gradient = ops.cast(gradient, variable.dtype)

        accumulator = self._accumulators[self._get_variable_index(variable)]

        accumulator.assign(accumulator + gradient * gradient)
        variable.assign(
            variable - (lr * gradient / ops.sqrt(accumulator + self.epsilon))
        )

    def get_config(self):
        config = super().get_config()

        config.update(
            {
                "initial_accumulator_value": self.initial_accumulator_value,
                "epsilon": self.epsilon,
            }
        )
        return config


Adagrad.__doc__ = Adagrad.__doc__.replace(
    "{{base_optimizer_keyword_args}}", optimizer.base_optimizer_keyword_args
)
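
Analogous to the Adadelta sketch earlier, a minimal NumPy version of the recurrence in `Adagrad.update_step` above (`adagrad_step` is an illustrative helper, not part of the API). With the default `initial_accumulator_value=0.1` it reproduces the values checked in `test_single_step` in the test file below.

import numpy as np


def adagrad_step(var, grad, accumulator, lr=0.001, eps=1e-7):
    # Mirrors update_step above: accumulate squared gradients and scale the
    # step by their inverse square root.
    accumulator = accumulator + grad * grad
    var = var - lr * grad / np.sqrt(accumulator + eps)
    return var, accumulator


# Matches test_single_step below (learning_rate=0.5, accumulators at 0.1).
var = np.array([1.0, 2.0, 3.0, 4.0])
grad = np.array([1.0, 6.0, 7.0, 2.0])
acc = np.full(4, 0.1)  # initial_accumulator_value
var, acc = adagrad_step(var, grad, acc, lr=0.5)
print(var)  # approximately [0.5233, 1.5007, 2.5005, 3.5061]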
85 changes: 85 additions & 0 deletions keras_core/optimizers/adagrad_test.py
@@ -0,0 +1,85 @@
# flake8: noqa


import numpy as np

from keras_core import backend
from keras_core import testing
from keras_core.optimizers.adagrad import Adagrad


class AdagradTest(testing.TestCase):
    def test_config(self):
        optimizer = Adagrad(
            learning_rate=0.5,
            initial_accumulator_value=0.2,
            epsilon=1e-5,
        )
        self.run_class_serialization_test(optimizer)

    def test_single_step(self):
        optimizer = Adagrad(learning_rate=0.5)
        grads = np.array([1.0, 6.0, 7.0, 2.0])
        vars = backend.Variable([1.0, 2.0, 3.0, 4.0])
        optimizer.apply_gradients(zip([grads], [vars]))
        self.assertAllClose(
            vars, [0.5233, 1.5007, 2.5005, 3.5061], rtol=1e-4, atol=1e-4
        )

    def test_weight_decay(self):
        grads, var1, var2, var3 = (
            np.zeros(()),
            backend.Variable(2.0),
            backend.Variable(2.0, name="exclude"),
            backend.Variable(2.0),
        )
        optimizer_1 = Adagrad(learning_rate=1.0, weight_decay=0.004)
        optimizer_1.apply_gradients(zip([grads], [var1]))

        optimizer_2 = Adagrad(learning_rate=1.0, weight_decay=0.004)
        optimizer_2.exclude_from_weight_decay(var_names=["exclude"])
        optimizer_2.apply_gradients(zip([grads, grads], [var1, var2]))

        optimizer_3 = Adagrad(learning_rate=1.0, weight_decay=0.004)
        optimizer_3.exclude_from_weight_decay(var_list=[var3])
        optimizer_3.apply_gradients(zip([grads, grads], [var1, var3]))

        self.assertAlmostEqual(var1.numpy(), 1.9760959, decimal=6)
        self.assertAlmostEqual(var2.numpy(), 2.0, decimal=6)
        self.assertAlmostEqual(var3.numpy(), 2.0, decimal=6)

    def test_correctness_with_golden(self):
        optimizer = Adagrad(
            learning_rate=0.2, initial_accumulator_value=0.3, epsilon=1e-6
        )

        x = backend.Variable(np.ones([10]))
        grads = np.arange(0.1, 1.1, 0.1)
        first_grads = np.full((10,), 0.01)

        # fmt: off
        golden = np.array(
            [[0.9963, 0.9963, 0.9963, 0.9963, 0.9963, 0.9963, 0.9963, 0.9963, 0.9963, 0.9963],
             [0.9604, 0.9278, 0.9003, 0.8784, 0.8615, 0.8487, 0.8388, 0.8313, 0.8255, 0.8209],
             [0.9251, 0.8629, 0.8137, 0.7768, 0.7497, 0.7298, 0.7151, 0.704, 0.6956, 0.6891],
             [0.8903, 0.8012, 0.7342, 0.6862, 0.6521, 0.6277, 0.6099, 0.5967, 0.5867, 0.579],
             [0.856, 0.7422, 0.6604, 0.6037, 0.5644, 0.5367, 0.5168, 0.5021, 0.491, 0.4825]]
        )
        # fmt: on

        optimizer.apply_gradients(zip([first_grads], [x]))
        for i in range(5):
            self.assertAllClose(x, golden[i], rtol=5e-4, atol=5e-4)
            optimizer.apply_gradients(zip([grads], [x]))

    def test_clip_norm(self):
        optimizer = Adagrad(clipnorm=1)
        grad = [np.array([100.0, 100.0])]
        clipped_grad = optimizer._clip_gradients(grad)
        self.assertAllClose(clipped_grad[0], [2**0.5 / 2, 2**0.5 / 2])

    def test_clip_value(self):
        optimizer = Adagrad(clipvalue=1)
        grad = [np.array([100.0, 100.0])]
        clipped_grad = optimizer._clip_gradients(grad)
        self.assertAllClose(clipped_grad[0], [1.0, 1.0])
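
The clipping tests in both suites exercise behaviour inherited from the optimizer base class rather than anything specific to Adagrad or Adadelta: `clipnorm=1` rescales the gradient `[100, 100]`, whose L2 norm is `100 * sqrt(2)`, down to unit norm, giving `[sqrt(2)/2, sqrt(2)/2]`, while `clipvalue=1` clamps each element to 1. A quick NumPy check of the expected numbers (purely illustrative; this is not the base optimizer's implementation):

import numpy as np

grad = np.array([100.0, 100.0])

# clipnorm=1: rescale the gradient so its L2 norm is at most 1.
clipnorm = 1.0
scale = min(1.0, clipnorm / np.linalg.norm(grad))
print(grad * scale)  # approximately [0.7071, 0.7071], i.e. 2**0.5 / 2 each

# clipvalue=1: clamp each element to the range [-1, 1].
print(np.clip(grad, -1.0, 1.0))  # [1.0, 1.0]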
