From d9b92cafb5b5b796984ae15cbfa40b822a200f23 Mon Sep 17 00:00:00 2001 From: Neel Kovelamudi <60985914+nkovela1@users.noreply.github.com> Date: Wed, 3 May 2023 21:16:00 +0000 Subject: [PATCH] Adds all remaining Keras optimizers (Adamax, Adafactor, Nadam, and Ftrl) (#80) * Add golden correctness tests for Adam and SGD * Fix dtype issues * Sync with main (#56) * Minor touch ups * Fix a pretty major bug * Format code * Big rethink of Variable API * Make build-by-run the default build(), leveraging new zero_history KerasTensor mode * Minor fixes * Format code * Switch back to build-by-eager-run for simplicity * Add raise upon build failure * Work around JAX bug. * Add a few more tests. * Add saving tests * Adds test suite for SGD and golden correctness tests for all optimizers (#40) * Add golden correctness tests for Adam and SGD * Fix dtype issues * Add binary accuracy (#41) * chore: adding binary accuracy * chore: fix docstring * Add tests for add_loss and activity regularization. * Reformat code * Add ActivityRegularization layer * Fix JAX CI. * Add Lambda Callback (#42) * Add LambdaCallback * Add Lambda Callback * Add Lambda Callback * Rename lambda_callback_test.py * Add einsum (#43) * Add einsum * address comments * Fix format line length (#45) * Add Embedding layer * Shorten lines * Add .vscode to .gitignore (#46) * rm vscode settings * add .vscode to gitignore * Set demo program backend (#48) * Add tests for training arg resolution in Layer. * Implement mixed precision. * Replace backend.execute with backend.numpy.XXX (#50) * Add cosine similarity loss and update l2_normalize from regularizers (#34) * Begin cosine loss * Add testing for cosine similarity * Fix formatting * Docstring standardization * Formatting * Create numerical_utils * Fix issue with call context lingering. * Add the EarlyStopping callback (#44) * add earlystopping callback * addressing comments * address comments * addressing comments * remove unused imports * re-enable imports checks (#51) * Add nn.one_hot (#52) * Add GaussianDropout layer. * Add GaussianNoise layer * Add Categorical Accuracy Metric (#47) * chore: adding categorical accuracy metric * chore: reformat docstrings * chore: reformat * chore: ndims with len * refactor the docstring * Fix typos * Implement masking. 
--------- Co-authored-by: Francois Chollet Co-authored-by: Aritra Roy Gosthipaty Co-authored-by: Ramesh Sampath <1437573+sampathweb@users.noreply.github.com> Co-authored-by: Chen Qian Co-authored-by: Haifeng Jin <5476582+haifeng-jin@users.noreply.github.com> Co-authored-by: Gabriel Rasskin <43894452+grasskin@users.noreply.github.com> * Adds rmsprop optimizer and tests * Add AdamW optimizer and tests, minor formatting changes * Implemented formatting fixes * Adds clip norm and clip value tests to Adam * Adds Adagrad and Adadelta optimizers * Applies fixes to formatting and deletes unnecessary kwargs * Adds Adamax and Adafactor and associated tests * Adds Nadam and Ftrl optimizers and associated tests --------- Co-authored-by: Francois Chollet Co-authored-by: Aritra Roy Gosthipaty Co-authored-by: Ramesh Sampath <1437573+sampathweb@users.noreply.github.com> Co-authored-by: Chen Qian Co-authored-by: Haifeng Jin <5476582+haifeng-jin@users.noreply.github.com> Co-authored-by: Gabriel Rasskin <43894452+grasskin@users.noreply.github.com> --- keras_core/optimizers/adafactor.py | 190 ++++++++++++++++++++ keras_core/optimizers/adafactor_test.py | 93 ++++++++++ keras_core/optimizers/adamax.py | 141 +++++++++++++++ keras_core/optimizers/adamax_test.py | 84 +++++++++ keras_core/optimizers/ftrl.py | 227 ++++++++++++++++++++++++ keras_core/optimizers/ftrl_test.py | 73 ++++++++ keras_core/optimizers/nadam.py | 158 +++++++++++++++++ keras_core/optimizers/nadam_test.py | 89 ++++++++++ 8 files changed, 1055 insertions(+) create mode 100644 keras_core/optimizers/adafactor.py create mode 100644 keras_core/optimizers/adafactor_test.py create mode 100644 keras_core/optimizers/adamax.py create mode 100644 keras_core/optimizers/adamax_test.py create mode 100644 keras_core/optimizers/ftrl.py create mode 100644 keras_core/optimizers/ftrl_test.py create mode 100644 keras_core/optimizers/nadam.py create mode 100644 keras_core/optimizers/nadam_test.py diff --git a/keras_core/optimizers/adafactor.py b/keras_core/optimizers/adafactor.py new file mode 100644 index 00000000000..fb721a2b5d8 --- /dev/null +++ b/keras_core/optimizers/adafactor.py @@ -0,0 +1,190 @@ +from keras_core import backend +from keras_core import operations as ops +from keras_core.api_export import keras_core_export +from keras_core.optimizers import optimizer + + +@keras_core_export(["keras_core.optimizers.Adafactor"]) +class Adafactor(optimizer.Optimizer): + """Optimizer that implements the Adafactor algorithm. + + Adafactor is commonly used in NLP tasks, and has the advantage + of using less memory because it only saves partial information about + previous gradients. + + The default argument setup is based on the original paper (see reference). + When gradients are of dimension > 2, the Adafactor optimizer will delete + the last 2 dimensions separately in its accumulator variables. + + Args: + learning_rate: Initial value for the learning rate: + a floating point value. Defaults to 0.001. + beta_2_decay: float, defaults to -0.8. The decay rate of `beta_2`. + epsilon_1: float, defaults to 1e-30. A small offset to keep the + denominator away from 0. + epsilon_2: float, defaults to 1e-3. A small offset to avoid the learning + rate becoming too small over time. + clip_threshold: float, defaults to 1.0. Clipping threshold. This is a + part of the Adafactor algorithm, independent from `clipnorm`, + `clipvalue`, and `global_clipnorm`. + relative_step: bool, defaults to True.
If `learning_rate` is a + constant and `relative_step=True`, the learning rate will be adjusted + based on the current iteration count. This is the default learning rate + decay in Adafactor. + {{base_optimizer_keyword_args}} + + Reference: + + - [Shazeer, Noam et al., 2018](https://arxiv.org/abs/1804.04235). + + """ + + def __init__( + self, + learning_rate=0.001, + beta_2_decay=-0.8, + epsilon_1=1e-30, + epsilon_2=1e-3, + clip_threshold=1.0, + relative_step=True, + weight_decay=None, + clipnorm=None, + clipvalue=None, + global_clipnorm=None, + use_ema=False, + ema_momentum=0.99, + ema_overwrite_frequency=None, + name="adafactor", + ): + super().__init__( + learning_rate=learning_rate, + name=name, + weight_decay=weight_decay, + clipnorm=clipnorm, + clipvalue=clipvalue, + global_clipnorm=global_clipnorm, + use_ema=use_ema, + ema_momentum=ema_momentum, + ema_overwrite_frequency=ema_overwrite_frequency, + ) + self.beta_2_decay = beta_2_decay + self.epsilon_1 = epsilon_1 + self.epsilon_2 = epsilon_2 + self.clip_threshold = clip_threshold + self.relative_step = relative_step + + def build(self, var_list): + """Initialize optimizer variables. + + Adafactor optimizer has 3 types of variables: row accumulators (`_r`), + column accumulators (`_c`), and velocities (`_v`). + + Args: + var_list: list of model variables to build Adafactor variables on. + """ + if self.built: + return + super().build(var_list) + self._r = [] + self._c = [] + self._v = [] + for var in var_list: + if len(var.shape) < 2: + # Don't factor if variable is of dimension < 2, but we still + # need to create dummy variables as placeholders. + self._r.append(backend.Variable(0, name=var.name)) + self._c.append(backend.Variable(0, name=var.name)) + else: + # Always factor the last 2 dimensions. + r_shape = var.shape[:-1] + c_shape = var.shape[:-2] + var.shape[-1] + self._r.append( + self.add_variable( + shape=r_shape, + dtype=var.dtype, + name=var.name, + ) + ) + self._c.append( + self.add_variable( + shape=c_shape, + dtype=var.dtype, + name=var.name, + ) + ) + self._v.append( + self.add_variable_from_reference( + reference_variable=var, name="v" + ) + ) + + def _rms(self, x): + return ops.sqrt(ops.mean(ops.square(x))) + + def update_step(self, gradient, variable, learning_rate): + """Update step given gradient and the associated model variable.""" + + lr = ops.cast(learning_rate, variable.dtype) + gradient = ops.cast(gradient, variable.dtype) + epsilon_2 = ops.cast(self.epsilon_2, variable.dtype) + one = ops.cast(1.0, variable.dtype) + local_step = ops.cast(self.iterations + 1, variable.dtype) + if self.relative_step: # TODO: add learning_rate_schedule logic + # If `relative_step=True` and learning rate is a constant, we + # apply the relative step algorithm. + lr = ops.minimum(lr, 1 / ops.sqrt(local_step)) + + r = self._r[self._get_variable_index(variable)] + c = self._c[self._get_variable_index(variable)] + v = self._v[self._get_variable_index(variable)] + + rho_t = ops.minimum(lr, 1 / ops.sqrt(local_step)) + alpha_t = ops.maximum(epsilon_2, self._rms(variable)) * rho_t + regulated_grad_square = ops.square(gradient) + self.epsilon_1 + beta_2_t = 1 - ops.power(local_step, self.beta_2_decay) + + if len(variable.shape) >= 2: + # `r` deletes the last dimension of gradient, so it is of shape + # `gradient.shape[:-1]`. + r.assign( + beta_2_t * r + + (1 - beta_2_t) * ops.mean(regulated_grad_square, axis=-1) + ) + # `c` deletes the second last dimension of gradient, so it is of + # shape `gradient.shape[:-2] + gradient.shape[-1]`.
+ c.assign( + beta_2_t * c + + (1 - beta_2_t) * ops.mean(regulated_grad_square, axis=-2) + ) + v.assign( + ops.expand_dims( + r / ops.mean(r, axis=-1, keepdims=True), axis=-1 + ) + * ops.expand_dims(c, -2) + ) + else: + v.assign(beta_2_t * v + (1 - beta_2_t) * regulated_grad_square) + + # `convert_to_tensor` unifies the handling of sparse and dense grads. + u_t = gradient / ops.sqrt(v) + u_t_hat = u_t / ops.maximum(one, (self._rms(u_t) / self.clip_threshold)) + variable.assign(variable - alpha_t * u_t_hat) + + def get_config(self): + config = super().get_config() + + config.update( + { + "beta_2_decay": self.beta_2_decay, + "epsilon_1": self.epsilon_1, + "epsilon_2": self.epsilon_2, + "clip_threshold": self.clip_threshold, + "relative_step": self.relative_step, + } + ) + return config + + +Adafactor.__doc__ = Adafactor.__doc__.replace( + "{{base_optimizer_keyword_args}}", optimizer.base_optimizer_keyword_args +) diff --git a/keras_core/optimizers/adafactor_test.py b/keras_core/optimizers/adafactor_test.py new file mode 100644 index 00000000000..05c3d8af4c6 --- /dev/null +++ b/keras_core/optimizers/adafactor_test.py @@ -0,0 +1,93 @@ +# flake8: noqa + + +import numpy as np + +from keras_core import backend +from keras_core import testing +from keras_core.optimizers.adafactor import Adafactor + + +class AdafactorTest(testing.TestCase): + def test_config(self): + optimizer = Adafactor( + learning_rate=0.5, + beta_2_decay=-0.65, + epsilon_1=1e-15, + epsilon_2=1e-4, + clip_threshold=0.9, + relative_step=False, + ) + self.run_class_serialization_test(optimizer) + + def test_single_step(self): + optimizer = Adafactor(learning_rate=0.5) + grads = np.array([1.0, 6.0, 7.0, 2.0]) + vars = backend.Variable([1.0, 2.0, 3.0, 4.0]) + optimizer.apply_gradients(zip([grads], [vars])) + self.assertAllClose( + vars, [-0.3693, 0.6307, 1.6307, 2.6307], rtol=1e-4, atol=1e-4 + ) + + def test_weight_decay(self): + grads, var1, var2, var3 = ( + np.zeros(()), + backend.Variable(2.0), + backend.Variable(2.0, name="exclude"), + backend.Variable(2.0), + ) + optimizer_1 = Adafactor(learning_rate=1.0, weight_decay=0.004) + optimizer_1.apply_gradients(zip([grads], [var1])) + + optimizer_2 = Adafactor(learning_rate=1.0, weight_decay=0.004) + optimizer_2.exclude_from_weight_decay(var_names=["exclude"]) + optimizer_2.apply_gradients(zip([grads, grads], [var1, var2])) + + optimizer_3 = Adafactor(learning_rate=1.0, weight_decay=0.004) + optimizer_3.exclude_from_weight_decay(var_list=[var3]) + optimizer_3.apply_gradients(zip([grads, grads], [var1, var3])) + + self.assertAlmostEqual(var1.numpy(), 1.9760959, decimal=6) + self.assertAlmostEqual(var2.numpy(), 2.0, decimal=6) + self.assertAlmostEqual(var3.numpy(), 2.0, decimal=6) + + def test_correctness_with_golden(self): + optimizer = Adafactor( + learning_rate=0.5, + beta_2_decay=-0.65, + epsilon_1=1e-15, + epsilon_2=1e-4, + clip_threshold=0.9, + relative_step=False, + ) + + x = backend.Variable(np.ones([10])) + grads = np.arange(0.1, 1.1, 0.1) + first_grads = np.full((10,), 0.01) + + # fmt: off + golden = np.array( + [[0.55, 0.55, 0.55, 0.55, 0.55, 0.55, 0.55, 0.55, 0.55, 0.55], + [0.3031, 0.3026, 0.3025, 0.3024, 0.3024, 0.3024, 0.3024, 0.3024, 0.3024, 0.3024], + [0.1671, 0.1665, 0.1663, 0.1663, 0.1663, 0.1663, 0.1663, 0.1663, 0.1663, 0.1663], + [0.0923, 0.0916, 0.0915, 0.0914, 0.0914, 0.0914, 0.0914, 0.0914, 0.0914, 0.0914], + [0.0554, 0.0548, 0.0546, 0.0546, 0.0546, 0.0546, 0.0546, 0.0545, 0.0545, 0.0545]] + ) + # fmt: on + + optimizer.apply_gradients(zip([first_grads], [x])) + 
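+        # The call above applies `first_grads` once; each loop iteration below
+        # first checks `x` against the corresponding golden row, then applies
+        # `grads` again before the next comparison.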
for i in range(5): + self.assertAllClose(x, golden[i], rtol=5e-4, atol=5e-4) + optimizer.apply_gradients(zip([grads], [x])) + + def test_clip_norm(self): + optimizer = Adafactor(clipnorm=1) + grad = [np.array([100.0, 100.0])] + clipped_grad = optimizer._clip_gradients(grad) + self.assertAllClose(clipped_grad[0], [2**0.5 / 2, 2**0.5 / 2]) + + def test_clip_value(self): + optimizer = Adafactor(clipvalue=1) + grad = [np.array([100.0, 100.0])] + clipped_grad = optimizer._clip_gradients(grad) + self.assertAllClose(clipped_grad[0], [1.0, 1.0]) diff --git a/keras_core/optimizers/adamax.py b/keras_core/optimizers/adamax.py new file mode 100644 index 00000000000..d082799321d --- /dev/null +++ b/keras_core/optimizers/adamax.py @@ -0,0 +1,141 @@ +from keras_core import operations as ops +from keras_core.api_export import keras_core_export +from keras_core.optimizers import optimizer + + +@keras_core_export(["keras_core.optimizers.Adamax"]) +class Adamax(optimizer.Optimizer): + """Optimizer that implements the Adamax algorithm. + + Adamax, a variant of Adam based on the infinity norm, is a first-order + gradient-based optimization method. Due to its capability of adjusting the + learning rate based on data characteristics, it is well suited to learning + time-variant processes, e.g., speech data with dynamically changing noise + conditions. Default parameters follow those provided in the paper (see + references below). + + Initialization: + + ```python + m = 0  # Initialize the 1st moment vector + u = 0  # Initialize the exponentially weighted infinity norm + t = 0  # Initialize timestep + ``` + + The update rule for parameter `w` with gradient `g` is described at the end + of section 7.1 of the paper (see the reference section): + + ```python + t += 1 + m = beta1 * m + (1 - beta1) * g + u = max(beta2 * u, abs(g)) + current_lr = learning_rate / (1 - beta1 ** t) + w = w - current_lr * m / (u + epsilon) + ``` + + Args: + learning_rate: A floating point value, or a callable + that takes no arguments and returns the actual value to use. The + learning rate. Defaults to `0.001`. + beta_1: A float value or a constant float tensor. The exponential decay + rate for the 1st moment estimates. + beta_2: A float value or a constant float tensor. The exponential decay + rate for the exponentially weighted infinity norm. + epsilon: A small constant for numerical stability. + {{base_optimizer_keyword_args}} + + Reference: + + - [Kingma et al., 2014](http://arxiv.org/abs/1412.6980) + """ + + def __init__( + self, + learning_rate=0.001, + beta_1=0.9, + beta_2=0.999, + epsilon=1e-7, + weight_decay=None, + clipnorm=None, + clipvalue=None, + global_clipnorm=None, + use_ema=False, + ema_momentum=0.99, + ema_overwrite_frequency=None, + name="adamax", + ): + super().__init__( + learning_rate=learning_rate, + name=name, + weight_decay=weight_decay, + clipnorm=clipnorm, + clipvalue=clipvalue, + global_clipnorm=global_clipnorm, + use_ema=use_ema, + ema_momentum=ema_momentum, + ema_overwrite_frequency=ema_overwrite_frequency, + ) + self.beta_1 = beta_1 + self.beta_2 = beta_2 + self.epsilon = epsilon + + def build(self, var_list): + """Initialize optimizer variables. + + Adamax optimizer has 2 types of variables: momentums (denoted as m) and + the exponentially weighted infinity norm (denoted as u). + + Args: + var_list: list of model variables to build Adamax variables on.
+ """ + if self.built: + return + super().build(var_list) + self._m = [] + self._u = [] + for var in var_list: + self._m.append( + self.add_variable_from_reference( + reference_variable=var, name="m" + ) + ) + self._u.append( + self.add_variable_from_reference( + reference_variable=var, name="u" + ) + ) + + def update_step(self, gradient, variable, learning_rate): + """Update step given gradient and the associated model variable.""" + lr = ops.cast(learning_rate, variable.dtype) + gradient = ops.cast(gradient, variable.dtype) + local_step = ops.cast(self.iterations + 1, variable.dtype) + beta_1_power = ops.power( + ops.cast(self.beta_1, variable.dtype), local_step + ) + + m = self._m[self._get_variable_index(variable)] + u = self._u[self._get_variable_index(variable)] + + m.assign(m + (gradient - m) * (1 - self.beta_1)) + u.assign(ops.maximum(self.beta_2 * u, ops.abs(gradient))) + variable.assign( + variable - (lr * m) / ((1 - beta_1_power) * (u + self.epsilon)) + ) + + def get_config(self): + config = super().get_config() + + config.update( + { + "beta_1": self.beta_1, + "beta_2": self.beta_2, + "epsilon": self.epsilon, + } + ) + return config + + +Adamax.__doc__ = Adamax.__doc__.replace( + "{{base_optimizer_keyword_args}}", optimizer.base_optimizer_keyword_args +) diff --git a/keras_core/optimizers/adamax_test.py b/keras_core/optimizers/adamax_test.py new file mode 100644 index 00000000000..61cccd012ba --- /dev/null +++ b/keras_core/optimizers/adamax_test.py @@ -0,0 +1,84 @@ +# flake8: noqa + + +import numpy as np + +from keras_core import backend +from keras_core import testing +from keras_core.optimizers.adamax import Adamax + + +class AdamaxTest(testing.TestCase): + def test_config(self): + optimizer = Adamax( + learning_rate=0.5, + beta_1=0.8, + beta_2=0.95, + epsilon=1e-5, + ) + self.run_class_serialization_test(optimizer) + + def test_single_step(self): + optimizer = Adamax(learning_rate=0.5) + grads = np.array([1.0, 6.0, 7.0, 2.0]) + vars = backend.Variable([1.0, 2.0, 3.0, 4.0]) + optimizer.apply_gradients(zip([grads], [vars])) + self.assertAllClose(vars, [0.5, 1.5, 2.5, 3.5], rtol=1e-4, atol=1e-4) + + def test_weight_decay(self): + grads, var1, var2, var3 = ( + np.zeros(()), + backend.Variable(2.0), + backend.Variable(2.0, name="exclude"), + backend.Variable(2.0), + ) + optimizer_1 = Adamax(learning_rate=1.0, weight_decay=0.004) + optimizer_1.apply_gradients(zip([grads], [var1])) + + optimizer_2 = Adamax(learning_rate=1.0, weight_decay=0.004) + optimizer_2.exclude_from_weight_decay(var_names=["exclude"]) + optimizer_2.apply_gradients(zip([grads, grads], [var1, var2])) + + optimizer_3 = Adamax(learning_rate=1.0, weight_decay=0.004) + optimizer_3.exclude_from_weight_decay(var_list=[var3]) + optimizer_3.apply_gradients(zip([grads, grads], [var1, var3])) + + self.assertAlmostEqual(var1.numpy(), 1.9760959, decimal=6) + self.assertAlmostEqual(var2.numpy(), 2.0, decimal=6) + self.assertAlmostEqual(var3.numpy(), 2.0, decimal=6) + + def test_correctness_with_golden(self): + optimizer = Adamax( + learning_rate=0.2, beta_1=0.85, beta_2=0.95, epsilon=1e-6 + ) + + x = backend.Variable(np.ones([10])) + grads = np.arange(0.1, 1.1, 0.1) + first_grads = np.full((10,), 0.01) + + # fmt: off + golden = np.array( + [[0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8], + [0.6827, 0.6873, 0.6888, 0.6896, 0.6901, 0.6904, 0.6906, 0.6908, 0.6909, 0.691], + [0.5333, 0.5407, 0.5431, 0.5444, 0.5451, 0.5456, 0.546, 0.5462, 0.5464, 0.5466], + [0.368, 0.3773, 0.3804, 0.382, 0.3829, 0.3835, 0.384, 0.3843, 0.3846, 
0.3848], + [0.1933, 0.204, 0.2076, 0.2094, 0.2105, 0.2112, 0.2117, 0.2121, 0.2124, 0.2126]] + ) + # fmt: on + + optimizer.apply_gradients(zip([first_grads], [x])) + for i in range(5): + self.assertAllClose(x, golden[i], rtol=5e-4, atol=5e-4) + optimizer.apply_gradients(zip([grads], [x])) + + def test_clip_norm(self): + optimizer = Adamax(clipnorm=1) + grad = [np.array([100.0, 100.0])] + clipped_grad = optimizer._clip_gradients(grad) + self.assertAllClose(clipped_grad[0], [2**0.5 / 2, 2**0.5 / 2]) + + def test_clip_value(self): + optimizer = Adamax(clipvalue=1) + grad = [np.array([100.0, 100.0])] + clipped_grad = optimizer._clip_gradients(grad) + self.assertAllClose(clipped_grad[0], [1.0, 1.0]) diff --git a/keras_core/optimizers/ftrl.py b/keras_core/optimizers/ftrl.py new file mode 100644 index 00000000000..ea179f61939 --- /dev/null +++ b/keras_core/optimizers/ftrl.py @@ -0,0 +1,227 @@ +from keras_core import initializers +from keras_core import operations as ops +from keras_core.api_export import keras_core_export +from keras_core.optimizers import optimizer + + +@keras_core_export(["keras_core.optimizers.Ftrl"]) +class Ftrl(optimizer.Optimizer): + r"""Optimizer that implements the FTRL algorithm. + + "Follow The Regularized Leader" (FTRL) is an optimization algorithm + developed at Google for click-through rate prediction in the early 2010s. It + is most suitable for shallow models with large and sparse feature spaces. + The algorithm is described by + [McMahan et al., 2013](https://research.google.com/pubs/archive/41159.pdf). + The Keras version has support for both online L2 regularization + (the L2 regularization described in the paper + above) and shrinkage-type L2 regularization + (which is the addition of an L2 penalty to the loss function). + + Initialization: + + ```python + n = 0 + sigma = 0 + z = 0 + ``` + + Update rule for one variable `w`: + + ```python + prev_n = n + n = n + g ** 2 + sigma = (n ** -lr_power - prev_n ** -lr_power) / lr + z = z + g - sigma * w + if abs(z) < lambda_1: + w = 0 + else: + w = (sgn(z) * lambda_1 - z) / ((beta + sqrt(n)) / alpha + lambda_2) + ``` + + Notation: + + - `lr` is the learning rate + - `g` is the gradient for the variable + - `lambda_1` is the L1 regularization strength + - `lambda_2` is the L2 regularization strength + - `lr_power` is the power to scale n. + + Check the documentation for the `l2_shrinkage_regularization_strength` + parameter for more details when shrinkage is enabled, in which case gradient + is replaced with a gradient with shrinkage. + + Args: + learning_rate: A floating point value, or a callable that + takes no arguments and returns the actual value to use. The learning + rate. Defaults to `0.001`. + learning_rate_power: A float value, must be less or equal to zero. + Controls how the learning rate decreases during training. Use zero + for a fixed learning rate. + initial_accumulator_value: The starting value for accumulators. Only + zero or positive values are allowed. + l1_regularization_strength: A float value, must be greater than or equal + to zero. Defaults to `0.0`. + l2_regularization_strength: A float value, must be greater than or equal + to zero. Defaults to `0.0`. + l2_shrinkage_regularization_strength: A float value, must be greater + than or equal to zero. This differs from L2 above in that the L2 + above is a stabilization penalty, whereas this L2 shrinkage is a + magnitude penalty. When input is sparse shrinkage will only happen + on the active weights. 
+ beta: A float value, representing the beta value from the paper. + Defaults to 0.0. + {{base_optimizer_keyword_args}} + """ + + def __init__( + self, + learning_rate=0.001, + learning_rate_power=-0.5, + initial_accumulator_value=0.1, + l1_regularization_strength=0.0, + l2_regularization_strength=0.0, + l2_shrinkage_regularization_strength=0.0, + beta=0.0, + weight_decay=None, + clipnorm=None, + clipvalue=None, + global_clipnorm=None, + use_ema=False, + ema_momentum=0.99, + ema_overwrite_frequency=None, + name="ftrl", + ): + super().__init__( + learning_rate=learning_rate, + name=name, + weight_decay=weight_decay, + clipnorm=clipnorm, + clipvalue=clipvalue, + global_clipnorm=global_clipnorm, + use_ema=use_ema, + ema_momentum=ema_momentum, + ema_overwrite_frequency=ema_overwrite_frequency, + ) + + if initial_accumulator_value < 0.0: + raise ValueError( + "`initial_accumulator_value` needs to be positive or zero. " + "Received: initial_accumulator_value=" + f"{initial_accumulator_value}." + ) + if learning_rate_power > 0.0: + raise ValueError( + "`learning_rate_power` needs to be negative or zero. Received: " + f"learning_rate_power={learning_rate_power}." + ) + if l1_regularization_strength < 0.0: + raise ValueError( + "`l1_regularization_strength` needs to be positive or zero. " + "Received: l1_regularization_strength=" + f"{l1_regularization_strength}." + ) + if l2_regularization_strength < 0.0: + raise ValueError( + "`l2_regularization_strength` needs to be positive or zero. " + "Received: l2_regularization_strength=" + f"{l2_regularization_strength}." + ) + if l2_shrinkage_regularization_strength < 0.0: + raise ValueError( + "`l2_shrinkage_regularization_strength` needs to be positive " + "or zero. Received: l2_shrinkage_regularization_strength" + f"={l2_shrinkage_regularization_strength}." + ) + + self.learning_rate_power = learning_rate_power + self.initial_accumulator_value = initial_accumulator_value + self.l1_regularization_strength = l1_regularization_strength + self.l2_regularization_strength = l2_regularization_strength + self.l2_shrinkage_regularization_strength = ( + l2_shrinkage_regularization_strength + ) + self.beta = beta + + def build(self, var_list): + """Initialize optimizer variables. + + Args: + var_list: list of model variables to build Ftrl variables on. + """ + if self.built: + return + super().build(var_list) + self._accumulators = [] + self._linears = [] + for var in var_list: + self._accumulators.append( + self.add_variable( + shape=var.shape, + dtype=var.dtype, + name="accumulator", + initializer=initializers.Constant( + self.initial_accumulator_value, + ), + ) + ) + self._linears.append( + self.add_variable_from_reference( + reference_variable=var, name="linear" + ) + ) + + def update_step(self, gradient, variable, learning_rate): + """Update step given gradient and the associated model variable.""" + + lr = ops.cast(learning_rate, variable.dtype) + gradient = ops.cast(gradient, variable.dtype) + + accum = self._accumulators[self._get_variable_index(variable)] + linear = self._linears[self._get_variable_index(variable)] + + lr_power = self.learning_rate_power + l2_reg = self.l2_regularization_strength + l2_reg = l2_reg + self.beta / (2.0 * lr) + + # Ftrl optimizer has the same implementation for sparse and dense + # gradients update. 
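+        # The lines below implement the update rule from the class docstring:
+        # add the L2-shrinkage term to the gradient, accumulate the squared
+        # gradient into `new_accum`, update the `linear` (z) term with the
+        # sigma-scaled variable, and write back the closed-form solution
+        # (linear_clipped - linear) / quadratic.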
+ grad_to_use = ( + gradient + 2 * self.l2_shrinkage_regularization_strength * variable + ) + new_accum = accum + ops.power(gradient, 2) + linear.assign( + linear + + grad_to_use + - (ops.power(new_accum, -lr_power) - ops.power(accum, -lr_power)) + / lr + * variable + ) + quadratic = ops.power(new_accum, (-lr_power)) / lr + 2 * l2_reg + linear_clipped = ops.clip( + linear, + -self.l1_regularization_strength, + self.l1_regularization_strength, + ) + variable.assign((linear_clipped - linear) / quadratic) + accum.assign(new_accum) + + def get_config(self): + config = super().get_config() + + config.update( + { + "learning_rate_power": self.learning_rate_power, + "initial_accumulator_value": self.initial_accumulator_value, + "l1_regularization_strength": self.l1_regularization_strength, + "l2_regularization_strength": self.l2_regularization_strength, + "l2_shrinkage_regularization_strength": self.l2_shrinkage_regularization_strength, # noqa: E501 + "beta": self.beta, + } + ) + return config + + +Ftrl.__doc__ = Ftrl.__doc__.replace( + "{{base_optimizer_keyword_args}}", optimizer.base_optimizer_keyword_args +) diff --git a/keras_core/optimizers/ftrl_test.py b/keras_core/optimizers/ftrl_test.py new file mode 100644 index 00000000000..1fd9975ab22 --- /dev/null +++ b/keras_core/optimizers/ftrl_test.py @@ -0,0 +1,73 @@ +# flake8: noqa + + +import numpy as np + +from keras_core import backend +from keras_core import testing +from keras_core.optimizers.ftrl import Ftrl + + +class FtrlTest(testing.TestCase): + def test_config(self): + optimizer = Ftrl( + learning_rate=0.05, + learning_rate_power=-0.2, + initial_accumulator_value=0.4, + l1_regularization_strength=0.05, + l2_regularization_strength=0.15, + l2_shrinkage_regularization_strength=0.01, + beta=0.3, + ) + self.run_class_serialization_test(optimizer) + + def test_single_step(self): + optimizer = Ftrl(learning_rate=0.5) + grads = np.array([1.0, 6.0, 7.0, 2.0]) + vars = backend.Variable([1.0, 2.0, 3.0, 4.0]) + optimizer.apply_gradients(zip([grads], [vars])) + self.assertAllClose( + vars, [0.2218, 1.3954, 2.3651, 2.8814], rtol=1e-4, atol=1e-4 + ) + + def test_correctness_with_golden(self): + optimizer = Ftrl( + learning_rate=0.05, + learning_rate_power=-0.2, + initial_accumulator_value=0.4, + l1_regularization_strength=0.05, + l2_regularization_strength=0.15, + l2_shrinkage_regularization_strength=0.01, + beta=0.3, + ) + + x = backend.Variable(np.ones([10])) + grads = np.arange(0.1, 1.1, 0.1) + first_grads = np.full((10,), 0.01) + + # fmt: off + golden = np.array( + [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [-0.0034, -0.0077, -0.0118, -0.0157, -0.0194, -0.023, -0.0263, -0.0294, -0.0325, -0.0354], + [-0.0078, -0.0162, -0.0242, -0.0317, -0.0387, -0.0454, -0.0516, -0.0575, -0.0631, -0.0685], + [-0.0121, -0.0246, -0.0363, -0.0472, -0.0573, -0.0668, -0.0757, -0.0842, -0.0922, -0.0999], + [-0.0164, -0.0328, -0.0481, -0.0623, -0.0753, -0.0875, -0.099, -0.1098, -0.1201, -0.1299]] + ) + # fmt: on + + optimizer.apply_gradients(zip([first_grads], [x])) + for i in range(5): + self.assertAllClose(x, golden[i], rtol=5e-4, atol=5e-4) + optimizer.apply_gradients(zip([grads], [x])) + + def test_clip_norm(self): + optimizer = Ftrl(clipnorm=1) + grad = [np.array([100.0, 100.0])] + clipped_grad = optimizer._clip_gradients(grad) + self.assertAllClose(clipped_grad[0], [2**0.5 / 2, 2**0.5 / 2]) + + def test_clip_value(self): + optimizer = Ftrl(clipvalue=1) + grad = [np.array([100.0, 100.0])] + clipped_grad = optimizer._clip_gradients(grad) + 
self.assertAllClose(clipped_grad[0], [1.0, 1.0]) diff --git a/keras_core/optimizers/nadam.py b/keras_core/optimizers/nadam.py new file mode 100644 index 00000000000..0f617b24cdd --- /dev/null +++ b/keras_core/optimizers/nadam.py @@ -0,0 +1,158 @@ +from keras_core import backend +from keras_core import operations as ops +from keras_core.api_export import keras_core_export +from keras_core.optimizers import optimizer + + +@keras_core_export(["keras_core.optimizers.Nadam"]) +class Nadam(optimizer.Optimizer): + """Optimizer that implements the Nadam algorithm. + + Much like Adam is essentially RMSprop with momentum, Nadam is Adam with + Nesterov momentum. + + Args: + learning_rate: A floating point value or a callable + that takes no arguments and returns the actual value to use. The + learning rate. Defaults to `0.001`. + beta_1: A float value or a constant float tensor, or a callable + that takes no arguments and returns the actual value to use. The + exponential decay rate for the 1st moment estimates. + Defaults to `0.9`. + beta_2: A float value or a constant float tensor, or a callable + that takes no arguments and returns the actual value to use. The + exponential decay rate for the 2nd moment estimates. Defaults to + `0.999`. + epsilon: A small constant for numerical stability. This epsilon is + "epsilon hat" in the Kingma and Ba paper (in the formula just before + Section 2.1), not the epsilon in Algorithm 1 of the paper. + Defaults to `1e-7`. + {{base_optimizer_keyword_args}} + + Reference: + + - [Dozat, 2015](http://cs229.stanford.edu/proj2015/054_report.pdf). + + """ + + def __init__( + self, + learning_rate=0.001, + beta_1=0.9, + beta_2=0.999, + epsilon=1e-7, + weight_decay=None, + clipnorm=None, + clipvalue=None, + global_clipnorm=None, + use_ema=False, + ema_momentum=0.99, + ema_overwrite_frequency=None, + name="nadam", + ): + super().__init__( + learning_rate=learning_rate, + name=name, + weight_decay=weight_decay, + clipnorm=clipnorm, + clipvalue=clipvalue, + global_clipnorm=global_clipnorm, + use_ema=use_ema, + ema_momentum=ema_momentum, + ema_overwrite_frequency=ema_overwrite_frequency, + ) + self.beta_1 = beta_1 + self.beta_2 = beta_2 + self.epsilon = epsilon + + def build(self, var_list): + """Initialize optimizer variables. + + Nadam optimizer has 2 types of variables: momentums and velocities. + + Args: + var_list: list of model variables to build Nadam variables on. + """ + if self.built: + return + super().build(var_list) + self._momentums = [] + self._velocities = [] + self._u_product = backend.Variable(1.0, dtype=var_list[0].dtype) + # Keep a counter of how many times `_u_product` has been computed, to + # avoid duplicated computations.
+ self._u_product_counter = 1 + + for var in var_list: + self._momentums.append( + self.add_variable_from_reference( + reference_variable=var, name="m" + ) + ) + self._velocities.append( + self.add_variable_from_reference( + reference_variable=var, name="v" + ) + ) + + def update_step(self, gradient, variable, learning_rate): + """Update step given gradient and the associated model variable.""" + var_dtype = variable.dtype + lr = ops.cast(learning_rate, var_dtype) + gradient = ops.cast(gradient, var_dtype) + + local_step = ops.cast(self.iterations + 1, var_dtype) + next_step = ops.cast(self.iterations + 2, var_dtype) + decay = ops.cast(0.96, var_dtype) + beta_1 = ops.cast(self.beta_1, var_dtype) + beta_2 = ops.cast(self.beta_2, var_dtype) + u_t = beta_1 * (1.0 - 0.5 * (ops.power(decay, local_step))) + u_t_1 = beta_1 * (1.0 - 0.5 * (ops.power(decay, next_step))) + + def get_cached_u_product(): + return self._u_product + + def compute_new_u_product(): + u_product_t = self._u_product * u_t + self._u_product.assign(u_product_t) + self._u_product_counter += 1 + return u_product_t + + if self._u_product_counter == (self.iterations + 2): + u_product_t = get_cached_u_product() + else: + u_product_t = compute_new_u_product() + + u_product_t_1 = u_product_t * u_t_1 + beta_2_power = ops.power(beta_2, local_step) + + m = self._momentums[self._get_variable_index(variable)] + v = self._velocities[self._get_variable_index(variable)] + + m.assign(m + (gradient - m) * (1 - beta_1)) + v.assign(v + (ops.square(gradient) - v) * (1 - beta_2)) + m_hat = u_t_1 * m / (1 - u_product_t_1) + (1 - u_t) * gradient / ( + 1 - u_product_t + ) + v_hat = v / (1 - beta_2_power) + + variable.assign( + variable - (m_hat * lr) / (ops.sqrt(v_hat) + self.epsilon) + ) + + def get_config(self): + config = super().get_config() + + config.update( + { + "beta_1": self.beta_1, + "beta_2": self.beta_2, + "epsilon": self.epsilon, + } + ) + return config + + +Nadam.__doc__ = Nadam.__doc__.replace( + "{{base_optimizer_keyword_args}}", optimizer.base_optimizer_keyword_args +) diff --git a/keras_core/optimizers/nadam_test.py b/keras_core/optimizers/nadam_test.py new file mode 100644 index 00000000000..88d140bf63a --- /dev/null +++ b/keras_core/optimizers/nadam_test.py @@ -0,0 +1,89 @@ +# flake8: noqa + + +import numpy as np + +from keras_core import backend +from keras_core import testing +from keras_core.optimizers.nadam import Nadam + + +class NadamTest(testing.TestCase): + def test_config(self): + optimizer = Nadam( + learning_rate=0.5, + beta_1=0.5, + beta_2=0.67, + epsilon=1e-5, + ) + self.run_class_serialization_test(optimizer) + + def test_single_step(self): + optimizer = Nadam(learning_rate=0.5) + grads = np.array([1.0, 6.0, 7.0, 2.0]) + vars = backend.Variable([1.0, 2.0, 3.0, 4.0]) + optimizer.apply_gradients(zip([grads], [vars])) + self.assertAllClose( + vars, [0.4686, 1.4686, 2.4686, 3.4686], rtol=1e-4, atol=1e-4 + ) + + def test_weight_decay(self): + grads, var1, var2, var3 = ( + np.zeros(()), + backend.Variable(2.0), + backend.Variable(2.0, name="exclude"), + backend.Variable(2.0), + ) + optimizer_1 = Nadam(learning_rate=1.0, weight_decay=0.004) + optimizer_1.apply_gradients(zip([grads], [var1])) + + optimizer_2 = Nadam(learning_rate=1.0, weight_decay=0.004) + optimizer_2.exclude_from_weight_decay(var_names=["exclude"]) + optimizer_2.apply_gradients(zip([grads, grads], [var1, var2])) + + optimizer_3 = Nadam(learning_rate=1.0, weight_decay=0.004) + optimizer_3.exclude_from_weight_decay(var_list=[var3]) + 
optimizer_3.apply_gradients(zip([grads, grads], [var1, var3])) + + self.assertAlmostEqual(var1.numpy(), 1.9760959, decimal=6) + self.assertAlmostEqual(var2.numpy(), 2.0, decimal=6) + self.assertAlmostEqual(var3.numpy(), 2.0, decimal=6) + + def test_correctness_with_golden(self): + optimizer = Nadam( + learning_rate=0.5, + beta_1=0.5, + beta_2=0.67, + epsilon=1e-5, + ) + + x = backend.Variable(np.ones([10])) + grads = np.arange(0.1, 1.1, 0.1) + first_grads = np.full((10,), 0.01) + + # fmt: off + golden = np.array( + [[0.4281, 0.4281, 0.4281, 0.4281, 0.4281, 0.4281, 0.4281, 0.4281, 0.4281, 0.4281], + [-0.1738, -0.1731, -0.1726, -0.1723, -0.1721, -0.172, -0.1719, -0.1718, -0.1718, -0.1717], + [-0.7115, -0.7103, -0.7096, -0.7092, -0.709, -0.7088, -0.7086, -0.7085, -0.7085, -0.7084], + [-1.2335, -1.2322, -1.2313, -1.2309, -1.2306, -1.2304, -1.2302, -1.2301, -1.23, -1.2299], + [-1.7492, -1.7478, -1.7469, -1.7464, -1.7461, -1.7459, -1.7457, -1.7456, -1.7455, -1.7454]] + ) + # fmt: on + + optimizer.apply_gradients(zip([first_grads], [x])) + for i in range(5): + self.assertAllClose(x, golden[i], rtol=5e-4, atol=5e-4) + optimizer.apply_gradients(zip([grads], [x])) + + def test_clip_norm(self): + optimizer = Nadam(clipnorm=1) + grad = [np.array([100.0, 100.0])] + clipped_grad = optimizer._clip_gradients(grad) + self.assertAllClose(clipped_grad[0], [2**0.5 / 2, 2**0.5 / 2]) + + def test_clip_value(self): + optimizer = Nadam(clipvalue=1) + grad = [np.array([100.0, 100.0])] + clipped_grad = optimizer._clip_gradients(grad) + self.assertAllClose(clipped_grad[0], [1.0, 1.0])
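
As a quick cross-check of `AdamaxTest.test_single_step` above, the update rule quoted in the `Adamax` docstring can be replayed in plain NumPy. This is a minimal sketch, separate from the patch itself; it assumes the optimizer defaults (`beta_1=0.9`, `beta_2=0.999`, `epsilon=1e-7`) together with the `learning_rate=0.5`, gradients, and variable values used in that test.

```python
import numpy as np

# Hyperparameters matching Adamax(learning_rate=0.5) in test_single_step.
lr, beta_1, beta_2, epsilon = 0.5, 0.9, 0.999, 1e-7

w = np.array([1.0, 2.0, 3.0, 4.0])  # variable values from the test
g = np.array([1.0, 6.0, 7.0, 2.0])  # gradients from the test
m = np.zeros_like(w)                # 1st moment vector
u = np.zeros_like(w)                # exponentially weighted infinity norm

# One step of the update rule from the Adamax docstring.
t = 1
m = beta_1 * m + (1 - beta_1) * g
u = np.maximum(beta_2 * u, np.abs(g))
w = w - (lr * m) / ((1 - beta_1**t) * (u + epsilon))

print(w)  # ~[0.5, 1.5, 2.5, 3.5], the values asserted in test_single_step
```

The same replay-the-docstring approach can be adapted to sanity-check the other single-step tests in this patch.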