From ae54ba07cd192a92a16e3bf83837d2f59c999b32 Mon Sep 17 00:00:00 2001
From: jon-tow
Date: Fri, 28 Jun 2019 04:26:44 -0400
Subject: [PATCH] Add AdaMax optimizer

---
 Sources/TensorFlow/Optimizer.swift          | 100 ++++++++++++++++++++
 Tests/TensorFlowTests/SequentialTests.swift |   3 +
 2 files changed, 103 insertions(+)

diff --git a/Sources/TensorFlow/Optimizer.swift b/Sources/TensorFlow/Optimizer.swift
index 7679e5c66..eeb17ecf3 100644
--- a/Sources/TensorFlow/Optimizer.swift
+++ b/Sources/TensorFlow/Optimizer.swift
@@ -132,6 +132,106 @@ public class Adam<Model: Layer>: Optimizer
     }
 }
 
+/// AdaMax optimizer.
+///
+/// A variant of Adam based on the infinity-norm.
+///
+/// Reference: Section 7 of ["Adam - A Method for Stochastic Optimization"](
+/// https://arxiv.org/abs/1412.6980v8)
+public class AdaMax<Model: Layer>: Optimizer
+    where Model.AllDifferentiableVariables == Model.TangentVector {
+    public typealias Model = Model
+    /// The learning rate.
+    public var learningRate: Float
+    /// Decay rate used to estimate the first moment (mean) of gradients.
+    public var beta1: Float
+    /// Decay rate used to estimate the exponentially weighted infinity norm.
+    public var beta2: Float
+    /// A small scalar added to the denominator to improve numerical stability.
+    public var epsilon: Float
+    /// The learning rate decay.
+    public var decay: Float
+    /// The step count.
+    public var step: Int = 0
+    /// The first moments of the weights.
+    public var firstMoments: Model.TangentVector
+    /// The exponentially weighted infinity norm of the weights.
+    public var infinityNorm: Model.TangentVector
+
+    /// Note: The default parameters follow those provided in the paper.
+    public init(
+        for model: __shared Model,
+        learningRate: Float = 0.002,
+        beta1: Float = 0.9,
+        beta2: Float = 0.999,
+        epsilon: Float = 1e-8,
+        decay: Float = 0
+    ) {
+        precondition(learningRate >= 0, "Learning rate must be non-negative.")
+        precondition(0 <= beta1 && beta1 <= 1, "Beta parameter must be between 0 and 1.")
+        precondition(0 <= beta2 && beta2 <= 1, "Beta parameter must be between 0 and 1.")
+        precondition(decay >= 0, "Learning rate decay must be non-negative.")
+
+        self.learningRate = learningRate
+        self.beta1 = beta1
+        self.beta2 = beta2
+        self.epsilon = epsilon
+        self.decay = decay
+
+        // Initialize first moments and infinity norm to be zeros of the same shape.
+        // We can't use `Model.AllDifferentiableVariables.zero` due to the
+        // interaction between Key Paths and Differentiable Arrays.
+        firstMoments = model.allDifferentiableVariables
+        infinityNorm = model.allDifferentiableVariables
+        for kp in firstMoments.recursivelyAllWritableKeyPaths(to: Tensor<Float>.self) {
+            firstMoments[keyPath: kp].resetToZero()
+            infinityNorm[keyPath: kp].resetToZero()
+        }
+        for kp in firstMoments.recursivelyAllWritableKeyPaths(to: Tensor<Double>.self) {
+            firstMoments[keyPath: kp].resetToZero()
+            infinityNorm[keyPath: kp].resetToZero()
+        }
+    }
+
+    // TODO: Deprecate this when `Differentiable.AllDifferentiableVariables` is removed.
+    public func update(_ model: inout Model.AllDifferentiableVariables,
+                       along direction: Model.AllDifferentiableVariables) {
+        step += 1
+        let learningRate = self.learningRate * 1 / (1 + decay * Float(step))
+        // Note: `stepSize` is split into two lines to avoid the "compiler is unable to type-check
+        // this expression in reasonable time" error.
+        var stepSize = learningRate * sqrt(1 - pow(beta2, Float(step)))
+        stepSize = stepSize / (1 - pow(beta1, Float(step)))
+        // Update `Tensor<Float>` & `Tensor<Double>` variables.
+        for kp in model.recursivelyAllWritableKeyPaths(to: Tensor<Float>.self) {
+            firstMoments[keyPath: kp] =
+                (beta1 * firstMoments[keyPath: kp]) + (1 - beta1) * direction[keyPath: kp]
+            infinityNorm[keyPath: kp] =
+                max(beta2 * infinityNorm[keyPath: kp], abs(direction[keyPath: kp]))
+            // `stepSize` already contains the first-moment bias correction, 1 / (1 - beta1^t).
+            model[keyPath: kp] -=
+                stepSize * firstMoments[keyPath: kp]
+                / (infinityNorm[keyPath: kp] + Float(self.epsilon))
+        }
+        for kp in model.recursivelyAllWritableKeyPaths(to: Tensor<Double>.self) {
+            firstMoments[keyPath: kp] =
+                Double(beta1) * firstMoments[keyPath: kp]
+                + Double(1 - beta1) * direction[keyPath: kp]
+            infinityNorm[keyPath: kp] =
+                max(Double(beta2) * infinityNorm[keyPath: kp], abs(direction[keyPath: kp]))
+            // `stepSize` already contains the first-moment bias correction, 1 / (1 - beta1^t).
+            model[keyPath: kp] -=
+                Double(stepSize) * firstMoments[keyPath: kp]
+                / (infinityNorm[keyPath: kp] + Double(self.epsilon))
+        }
+    }
+
+    public func update(_ model: inout Model,
+                       along direction: Model.TangentVector) {
+        update(&model.allDifferentiableVariables, along: direction)
+    }
+}
+
 /// RMSProp optimizer.
 ///
 /// It is recommended to leave the parameters of this optimizer at their default values (except the
diff --git a/Tests/TensorFlowTests/SequentialTests.swift b/Tests/TensorFlowTests/SequentialTests.swift
index b09b5309b..313deda26 100644
--- a/Tests/TensorFlowTests/SequentialTests.swift
+++ b/Tests/TensorFlowTests/SequentialTests.swift
@@ -32,6 +32,7 @@ final class SequentialTests: XCTestCase {
         let sgd = SGD(for: model, learningRate: 0.02)
         let rmsprop = RMSProp(for: model, learningRate: 0.02)
         let adam = Adam(for: model, learningRate: 0.02)
+        let adamax = AdaMax(for: model, learningRate: 0.02)
         let adagrad = AdaGrad(for: model, learningRate: 0.02)
         let x: Tensor<Float> = [[0, 0], [0, 1], [1, 0], [1, 1]]
         let y: Tensor<Float> = [0, 1, 1, 0]
@@ -47,6 +48,8 @@ final class SequentialTests: XCTestCase {
             rmsprop.update(&model.allDifferentiableVariables, along: 𝛁model)
             adam.update(&model, along: 𝛁model)
             adam.update(&model.allDifferentiableVariables, along: 𝛁model)
+            adamax.update(&model, along: 𝛁model)
+            adamax.update(&model.allDifferentiableVariables, along: 𝛁model)
             adagrad.update(&model, along: 𝛁model)
             adagrad.update(&model.allDifferentiableVariables, along: 𝛁model)
         }
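Note on the update rule implemented by this patch: Section 7 of the Adam paper defines AdaMax as
m_t = beta1 * m_{t-1} + (1 - beta1) * g_t, u_t = max(beta2 * u_{t-1}, |g_t|), and
theta_t = theta_{t-1} - (alpha / (1 - beta1^t)) * m_t / u_t. The patch additionally follows this
library's Adam convention of scaling the step by sqrt(1 - beta2^t) and adding epsilon to the
denominator. The following sketch, which is not part of the patch, illustrates the plain paper
rule on `[Float]` buffers without the swift-apis `Optimizer` protocol; `AdaMaxSketch` and the toy
quadratic example are hypothetical names used only for this illustration.

/// Minimal, self-contained sketch of the AdaMax rule from Section 7 of the Adam
/// paper, written against plain `[Float]` buffers rather than swift-apis types.
struct AdaMaxSketch {
    var learningRate: Float = 0.002
    var beta1: Float = 0.9
    var beta2: Float = 0.999
    var epsilon: Float = 1e-8
    // Running product beta1^t, used for the bias-corrected step size.
    private var beta1Power: Float = 1
    private var firstMoments: [Float]
    private var infinityNorm: [Float]

    init(parameterCount: Int) {
        firstMoments = Array(repeating: 0, count: parameterCount)
        infinityNorm = Array(repeating: 0, count: parameterCount)
    }

    mutating func update(_ parameters: inout [Float], along gradients: [Float]) {
        beta1Power *= beta1
        // stepSize = alpha / (1 - beta1^t): bias correction for the first moment.
        let stepSize = learningRate / (1 - beta1Power)
        for i in parameters.indices {
            // m_t = beta1 * m_{t-1} + (1 - beta1) * g_t
            firstMoments[i] = beta1 * firstMoments[i] + (1 - beta1) * gradients[i]
            // u_t = max(beta2 * u_{t-1}, |g_t|)
            infinityNorm[i] = max(beta2 * infinityNorm[i], abs(gradients[i]))
            // theta_t = theta_{t-1} - stepSize * m_t / (u_t + epsilon)
            parameters[i] -= stepSize * firstMoments[i] / (infinityNorm[i] + epsilon)
        }
    }
}

// Toy usage: minimize f(x) = x^2 (gradient 2x) starting from x = 1.
var sketch = AdaMaxSketch(parameterCount: 1)
var parameters: [Float] = [1.0]
for _ in 0..<1000 {
    let gradients = parameters.map { 2 * $0 }
    sketch.update(&parameters, along: gradients)
}
print(parameters)  // Ends close to [0.0].

Within the library itself, usage mirrors the test change above: construct the optimizer with
`AdaMax(for: model, learningRate: 0.02)` and call `adamax.update(&model, along: 𝛁model)` inside
the training loop.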