From e408f82d68482d4fbb1d9512c038ed78ffbaa14e Mon Sep 17 00:00:00 2001 From: Richard Wei Date: Wed, 12 Jun 2019 00:26:46 -0700 Subject: [PATCH 01/12] [Optimizer] Simplify optimizers using generalized vector math. --- Sources/TensorFlow/Layer.swift | 6 ++++ Sources/TensorFlow/Operators/Math.swift | 8 +++-- Sources/TensorFlow/Optimizer.swift | 40 +++++-------------------- 3 files changed, 20 insertions(+), 34 deletions(-) diff --git a/Sources/TensorFlow/Layer.swift b/Sources/TensorFlow/Layer.swift index f634a80d2..e5e02aa36 100644 --- a/Sources/TensorFlow/Layer.swift +++ b/Sources/TensorFlow/Layer.swift @@ -82,6 +82,12 @@ public extension Layer { } } +public extension Layer { + mutating func += (lhs: inout Self, rhs: TangentVector) -> Self { + lhs = lhs.moved(along: rhs) + } +} + public extension Differentiable { /// Returns the output computed by applying a sequence of layers to the previous layer's output, /// except that the first layer's input is `self`. diff --git a/Sources/TensorFlow/Operators/Math.swift b/Sources/TensorFlow/Operators/Math.swift index 6eba21680..f7980ea1e 100644 --- a/Sources/TensorFlow/Operators/Math.swift +++ b/Sources/TensorFlow/Operators/Math.swift @@ -31,8 +31,12 @@ func pow(_ x: T, _ y: T) -> T { // Vector Space //===------------------------------------------------------------------------------------------===// -extension Tensor: VectorProtocol where Scalar: Numeric { - public typealias VectorSpaceScalar = Scalar +extension Tensor: VectorProtocol where Scalar: TensorFlowFloatingPoint { + public typealias VectorSpaceScalar = Float + + public func scaled(by scale: Float) -> Self { + Scalar(scale) * self + } } //===------------------------------------------------------------------------------------------===// diff --git a/Sources/TensorFlow/Optimizer.swift b/Sources/TensorFlow/Optimizer.swift index e7446604e..ba8c3ff9d 100644 --- a/Sources/TensorFlow/Optimizer.swift +++ b/Sources/TensorFlow/Optimizer.swift @@ -201,8 +201,7 @@ public class RMSProp: Optimizer /// /// An optimizer that implements stochastic gradient descent, with support for momentum, learning /// rate decay, and Nesterov momentum. -public class SGD: Optimizer - where Model.AllDifferentiableVariables == Model.TangentVector { +public class SGD: Optimizer where Model.TangentVector: VectorProtocol { /// The learning rate. public var learningRate: Float /// The momentum factor. It accelerates stochastic gradient descent in the relevant direction @@ -212,8 +211,8 @@ public class SGD: Optimizer public var decay: Float /// Use Nesterov momentum if true. public var nesterov: Bool - /// The velocity state of the model - public var velocity: Model.AllDifferentiableVariables + /// The velocity state of the model. + public var velocity: Model.TangentVector = .zero /// The set of steps taken. 
public var step: Int = 0 @@ -232,40 +231,17 @@ public class SGD: Optimizer self.momentum = momentum self.decay = decay self.nesterov = nesterov - velocity = model.allDifferentiableVariables - for kp in velocity.recursivelyAllWritableKeyPaths(to: Tensor.self) { - velocity[keyPath: kp].resetToZero() - } - for kp in velocity.recursivelyAllWritableKeyPaths(to: Tensor.self) { - velocity[keyPath: kp].resetToZero() - } } public func update(_ model: inout Model.AllDifferentiableVariables, along direction: Model.TangentVector) { step += 1 let learningRate = self.learningRate * 1 / (1 + decay * Float(step)) - for kp in model.recursivelyAllWritableKeyPaths(to: Tensor.self) { - velocity[keyPath: kp] = - momentum * velocity[keyPath: kp] - learningRate * direction[keyPath: kp] - if nesterov { - model[keyPath: kp] += - momentum * velocity[keyPath: kp] - learningRate * direction[keyPath: kp] - } else { - model[keyPath: kp] += velocity[keyPath: kp] - } - } - for kp in model.recursivelyAllWritableKeyPaths(to: Tensor.self) { - velocity[keyPath: kp] = - Double(momentum) * velocity[keyPath: kp] - - Double(learningRate) * direction[keyPath: kp] - if nesterov { - model[keyPath: kp] += - Double(momentum) * velocity[keyPath: kp] - Double(learningRate) * - direction[keyPath: kp] - } else { - model[keyPath: kp] += velocity[keyPath: kp] - } + velocity = momentum * velocity - direction.scaled(by: learningRate) + if nesterov { + model += momentum * velocity - direction.scaled(by: learningRate) + } else { + model += velocity } } } From 5873165341fee81dbd4782bd0cb8ff910db0cea2 Mon Sep 17 00:00:00 2001 From: Richard Wei Date: Wed, 12 Jun 2019 01:12:55 -0700 Subject: [PATCH 02/12] Make `+=` internal. --- Sources/TensorFlow/Layer.swift | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Sources/TensorFlow/Layer.swift b/Sources/TensorFlow/Layer.swift index e5e02aa36..b1450a518 100644 --- a/Sources/TensorFlow/Layer.swift +++ b/Sources/TensorFlow/Layer.swift @@ -82,7 +82,7 @@ public extension Layer { } } -public extension Layer { +internal extension Layer { mutating func += (lhs: inout Self, rhs: TangentVector) -> Self { lhs = lhs.moved(along: rhs) } From 1b15f31269079e35c2ab78813d2a39eaa0d4cfd1 Mon Sep 17 00:00:00 2001 From: Richard Wei Date: Wed, 12 Jun 2019 01:13:51 -0700 Subject: [PATCH 03/12] Relax the generic constraint on `Model` to just `Differentiable` --- Sources/TensorFlow/Optimizer.swift | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Sources/TensorFlow/Optimizer.swift b/Sources/TensorFlow/Optimizer.swift index ba8c3ff9d..ae3233b0a 100644 --- a/Sources/TensorFlow/Optimizer.swift +++ b/Sources/TensorFlow/Optimizer.swift @@ -201,7 +201,7 @@ public class RMSProp: Optimizer /// /// An optimizer that implements stochastic gradient descent, with support for momentum, learning /// rate decay, and Nesterov momentum. -public class SGD: Optimizer where Model.TangentVector: VectorProtocol { +public class SGD: Optimizer where Model.TangentVector: VectorProtocol { /// The learning rate. public var learningRate: Float /// The momentum factor. It accelerates stochastic gradient descent in the relevant direction From aac42754a6c7ee36225b24fc94d5c69f42aeb8e5 Mon Sep 17 00:00:00 2001 From: Richard Wei Date: Wed, 12 Jun 2019 01:53:39 -0700 Subject: [PATCH 04/12] Update all optimizers. 
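With `TangentVector` constrained to `VectorProtocol` (and `+=`/`-=` expressed via `move(along:)`), each optimizer body becomes a few whole-vector expressions instead of a pair of key-path loops over `Float` and `Double` tensors. For illustration only, a minimal free-standing sketch of an SGD step in this style — `sgdStep` and its parameter names are stand-ins, not API added by this patch:

    import TensorFlow

    // Sketch: one SGD step written with generalized vector math.
    func sgdStep<Model: Differentiable>(
        _ model: inout Model,
        along direction: Model.TangentVector,
        velocity: inout Model.TangentVector,
        learningRate: Float,
        momentum: Float
    ) where Model.TangentVector: VectorProtocol,
            Model.TangentVector.VectorSpaceScalar == Float {
        // v <- momentum * v - learningRate * g, then move the model along v.
        velocity = velocity.scaled(by: momentum) - direction.scaled(by: learningRate)
        model.move(along: velocity)
    }
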
--- Sources/TensorFlow/Layer.swift | 4 + Sources/TensorFlow/Optimizer.swift | 142 +++++++---------------------- 2 files changed, 36 insertions(+), 110 deletions(-) diff --git a/Sources/TensorFlow/Layer.swift b/Sources/TensorFlow/Layer.swift index b1450a518..2add37dfe 100644 --- a/Sources/TensorFlow/Layer.swift +++ b/Sources/TensorFlow/Layer.swift @@ -86,6 +86,10 @@ internal extension Layer { mutating func += (lhs: inout Self, rhs: TangentVector) -> Self { lhs = lhs.moved(along: rhs) } + + mutating func -= (lhs: inout Self, rhs: TangentVector) -> Self { + lhs = lhs.moved(along: .zero - rhs) + } } public extension Differentiable { diff --git a/Sources/TensorFlow/Optimizer.swift b/Sources/TensorFlow/Optimizer.swift index ae3233b0a..c8b7f4c41 100644 --- a/Sources/TensorFlow/Optimizer.swift +++ b/Sources/TensorFlow/Optimizer.swift @@ -25,8 +25,7 @@ public protocol Optimizer { var learningRate: Scalar { get set } /// Updates the specified differentiable variables along the specified /// direction. - mutating func update(_ variables: inout Model.AllDifferentiableVariables, - along direction: Model.TangentVector) + mutating func update(_ variables: inout Model, along direction: Model.TangentVector) } fileprivate extension Tensor where Scalar: Numeric { @@ -35,14 +34,11 @@ fileprivate extension Tensor where Scalar: Numeric { } } -// MARK: - Key-path based optimizers - /// Adam optimizer. /// /// Reference: ["Adam - A Method for Stochastic Optimization"]( /// https://arxiv.org/abs/1412.6980v8) -public class Adam: Optimizer - where Model.AllDifferentiableVariables == Model.TangentVector { +public class Adam: Optimizer where Model.TangentVector: VectorProtocol { /// The learning rate. public var learningRate: Float /// A coefficient used to calculate the first and second moments of @@ -58,9 +54,9 @@ public class Adam: Optimizer /// The current step. public var step: Int = 0 /// The first moments of the weights. - public var firstMoments: Model.AllDifferentiableVariables + public var firstMoments: Model.TangentVector = .zero /// The second moments of the weights. - public var secondMoments: Model.AllDifferentiableVariables + public var secondMoments: Model.TangentVector = .zero public init( for model: __shared Model, @@ -80,52 +76,20 @@ public class Adam: Optimizer self.beta2 = beta2 self.epsilon = epsilon self.decay = decay - - // Initialize first & second moments to be zeros of the same shape. - // We can't use `Model.AllDifferentiableVariables.zero` due to the - // interaction between Key Paths and Differentiable Arrays. - firstMoments = model.allDifferentiableVariables - secondMoments = model.allDifferentiableVariables - for kp in firstMoments.recursivelyAllWritableKeyPaths(to: Tensor.self) { - firstMoments[keyPath: kp].resetToZero() - secondMoments[keyPath: kp].resetToZero() - } - for kp in firstMoments.recursivelyAllWritableKeyPaths(to: Tensor.self) { - firstMoments[keyPath: kp].resetToZero() - secondMoments[keyPath: kp].resetToZero() - } } - public func update(_ model: inout Model.AllDifferentiableVariables, - along direction: Model.AllDifferentiableVariables) { + public func update(_ model: inout Model, along direction: Model.TangentVector) { step += 1 - let learningRate = self.learningRate * 1 / (1 + decay * Float(step)) + let learningRate = self.learningRate / (1 + decay * Float(step)) // Note: `stepSize` is split into two lines to avoid the "compiler is unable to type-check // this expression in reasonable time" error. 
var stepSize = learningRate * sqrt(1 - pow(beta2, Float(step))) stepSize = stepSize / (1 - pow(beta1, Float(step))) // Update Float & Double Tensor variables. - for kp in model.recursivelyAllWritableKeyPaths(to: Tensor.self) { - firstMoments[keyPath: kp] = - firstMoments[keyPath: kp] * beta1 + (1 - beta1) * direction[keyPath: kp] - secondMoments[keyPath: kp] = - secondMoments[keyPath: kp] * beta2 + (1 - beta2) * - direction[keyPath: kp] * direction[keyPath: kp] - model[keyPath: kp] -= - stepSize * firstMoments[keyPath: kp] / (sqrt(secondMoments[keyPath: kp]) + epsilon) - } - for kp in model.recursivelyAllWritableKeyPaths(to: Tensor.self) { - firstMoments[keyPath: kp] = - firstMoments[keyPath: kp] * Double(beta1) + - Double((1 - beta1)) * direction[keyPath: kp] - secondMoments[keyPath: kp] = - secondMoments[keyPath: kp] * Double(beta2) + Double(1 - beta2) * - direction[keyPath: kp] * direction[keyPath: kp] - model[keyPath: kp] -= - Double(stepSize) * firstMoments[keyPath: kp] / - sqrt(secondMoments[keyPath: kp]) + Double(epsilon) - } + firstMoments = firstMoments * beta1 + (1 - beta1) * direction + secondMoments = secondMoments * beta2 + (1 - beta2) * direction * direction + model -= stepSize * firstMoments / (sqrt(secondMoments) + epsilon) } } @@ -137,8 +101,7 @@ public class Adam: Optimizer /// /// Reference: ["rmsprop: Divide the gradient by a running average of its recent magnitude"]( /// http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf) -public class RMSProp: Optimizer - where Model.AllDifferentiableVariables == Model.TangentVector { +public class RMSProp: Optimizer where Model.TangentVector: VectorProtocol { /// The learning rate. public var learningRate: Float // TODO: Document `rho`. Keras doesn't document `rho`. @@ -150,7 +113,7 @@ public class RMSProp: Optimizer /// The step count. public var step: Float = 0 /// The alpha values for all model differentiable variables. 
- public var alpha: Model.AllDifferentiableVariables + public var alpha: Model.TangentVector = .zero public init( for model: __shared Model, @@ -167,33 +130,14 @@ public class RMSProp: Optimizer self.rho = rho self.epsilon = epsilon self.decay = decay - alpha = model.allDifferentiableVariables - for kp in alpha.recursivelyAllWritableKeyPaths(to: Tensor.self) { - alpha[keyPath: kp].resetToZero() - } - for kp in alpha.recursivelyAllWritableKeyPaths(to: Tensor.self) { - alpha[keyPath: kp].resetToZero() - } } - public func update(_ model: inout Model.AllDifferentiableVariables, - along direction: Model.TangentVector) { + public func update(_ model: inout Model, along direction: Model.TangentVector) { step += 1 - let learningRate = self.learningRate * 1 / (1 + decay * Float(step)) - for kp in model.recursivelyAllWritableKeyPaths(to: Tensor.self) { - alpha[keyPath: kp] = - rho * alpha[keyPath: kp] + (1 - rho) * pow(direction[keyPath: kp], 2) - model[keyPath: kp] -= - learningRate * direction[keyPath: kp] / (sqrt(alpha[keyPath: kp]) + epsilon) - } - for kp in model.recursivelyAllWritableKeyPaths(to: Tensor.self) { - alpha[keyPath: kp] = - Double(rho) * alpha[keyPath: kp] + Double(1 - rho) * pow(direction[keyPath: kp], 2) - model[keyPath: kp] -= - Double(learningRate) * direction[keyPath: kp] / - (sqrt(alpha[keyPath: kp]) + Double(epsilon)) - } + let learningRate = self.learningRate / (1 + decay * Float(step)) + alpha = rho * alpha + (1 - rho) * pow(direction, 2) + model -= learningRate * direction / (sqrt(alpha) + epsilon) } } @@ -233,13 +177,12 @@ public class SGD: Optimizer where Model.TangentVector: Ve self.nesterov = nesterov } - public func update(_ model: inout Model.AllDifferentiableVariables, - along direction: Model.TangentVector) { + public func update(_ model: inout Model, along direction: Model.TangentVector) { step += 1 let learningRate = self.learningRate * 1 / (1 + decay * Float(step)) - velocity = momentum * velocity - direction.scaled(by: learningRate) + velocity = momentum * velocity - direction * learningRate if nesterov { - model += momentum * velocity - direction.scaled(by: learningRate) + model += momentum * velocity - direction * learningRate } else { model += velocity } @@ -249,12 +192,12 @@ public class SGD: Optimizer where Model.TangentVector: Ve // MARK: - Manifold optimizers /// A Riemann manifold stochastic gradient descent (SGD) optimizer. -public class RiemannSGD: Optimizer - where Model.TangentVector: VectorProtocol, Model.TangentVector.VectorSpaceScalar == Scalar { +public class RiemannSGD: Optimizer + where Model.TangentVector: VectorProtocol { /// The learning rate. - public var learningRate: Scalar + public var learningRate: Model.TangentVector.Scalar - public init(learningRate: Scalar) { + public init(learningRate: Model.TangentVector.Scalar) { self.learningRate = learningRate } @@ -265,22 +208,20 @@ public class RiemannSGD: Optimizer self.init(learningRate: learningRate) } - public func update(_ model: inout Model.AllDifferentiableVariables, - along direction: Model.TangentVector) { + public func update(_ model: inout Model, along direction: Model.TangentVector) { model = model.moved(along: learningRate * (.zero - direction)) } } /// AdaGrad optimizer. /// -/// Individually adapts the learning rates of all model parameters by scaling them inversely proportional to +/// Individually adapts the learning rates of all model parameters by scaling them inversely proportional to /// the square root of the sum of all the historical squared values of the gradient. 
-/// +/// /// Reference: ["Adaptive Subgradient Methods for Online Learning and Stochastic Optimization"]( -/// http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf) -/// -public class AdaGrad: Optimizer - where Model.AllDifferentiableVariables == Model.TangentVector { +/// http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf) +/// +public class AdaGrad: Optimizer where Model.TangentVector: VectorProtocol { /// The learning rate. public var learningRate: Float /// The smoothing factor (ρ). Typical values are `0.5`, `0.9`, and `0.99`, for smoothing over 2, @@ -289,7 +230,7 @@ public class AdaGrad: Optimizer /// A small scalar added to the denominator to improve numerical stability. public var epsilon: Float /// The alpha values for all model differentiable variables. - public var alpha: Model.AllDifferentiableVariables + public var alpha: Model.TangentVector = .zero public init( for model: __shared Model, @@ -303,29 +244,10 @@ public class AdaGrad: Optimizer self.learningRate = learningRate self.rho = rho self.epsilon = epsilon - - alpha = model.allDifferentiableVariables - for kp in alpha.recursivelyAllWritableKeyPaths(to: Tensor.self) { - alpha[keyPath: kp].resetToZero() - } - for kp in alpha.recursivelyAllWritableKeyPaths(to: Tensor.self) { - alpha[keyPath: kp].resetToZero() - } } - public func update(_ model: inout Model.AllDifferentiableVariables, - along direction: Model.TangentVector) { - for kp in model.recursivelyAllWritableKeyPaths(to: Tensor.self) { - alpha[keyPath: kp] = rho + direction[keyPath: kp].squared() - model[keyPath: kp] -= - learningRate * direction[keyPath: kp] / (sqrt(alpha[keyPath: kp] + epsilon)) - } - for kp in model.recursivelyAllWritableKeyPaths(to: Tensor.self) { - alpha[keyPath: kp] = Double(rho) + direction[keyPath: kp].squared() - model[keyPath: kp] -= - Double(learningRate) * direction[keyPath: kp] / - (sqrt(alpha[keyPath: kp] + Double(epsilon))) - } + public func update(_ model: inout Model, along direction: Model.TangentVector) { + alpha = rho + direction.squared() + model -= learningRate * direction / (sqrt(alpha + epsilon)) } } - From 1846806de57d21fd8faf39902e0d3d6f1aa1daad Mon Sep 17 00:00:00 2001 From: Richard Wei Date: Wed, 12 Jun 2019 01:56:42 -0700 Subject: [PATCH 05/12] Make -= be defined with a SignedNumeric constraint. --- Sources/TensorFlow/Layer.swift | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/Sources/TensorFlow/Layer.swift b/Sources/TensorFlow/Layer.swift index 2add37dfe..84e62b482 100644 --- a/Sources/TensorFlow/Layer.swift +++ b/Sources/TensorFlow/Layer.swift @@ -82,13 +82,16 @@ public extension Layer { } } -internal extension Layer { +internal extension Layer where TangentVector: VectorProtocol { mutating func += (lhs: inout Self, rhs: TangentVector) -> Self { lhs = lhs.moved(along: rhs) } +} +internal extension Layer where TangentVector: VectorProtocol, + TangentVector.VectorSpaceScalar: SignedNumeric { mutating func -= (lhs: inout Self, rhs: TangentVector) -> Self { - lhs = lhs.moved(along: .zero - rhs) + lhs = lhs.moved(along: -rhs) } } From 498b66867e38771a148e37212ce1a6462ae215b7 Mon Sep 17 00:00:00 2001 From: Richard Wei Date: Fri, 14 Jun 2019 16:39:44 -0700 Subject: [PATCH 06/12] Extend 'Differentiable' to have '+=' and '-=' internally. 
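Moving the operators from `Layer` to `Differentiable` lets any differentiable value, not just layers, be shifted by a tangent vector. A rough sketch of what the two operators are meant to do — `nudge` is a hypothetical helper, purely for illustration:

    // `value += direction` moves along the direction; `value -= direction`
    // moves along the negated direction.
    func nudge<T: Differentiable>(_ value: inout T,
                                  by direction: T.TangentVector,
                                  subtracting: Bool = false) {
        if subtracting {
            value.move(along: .zero - direction)   // value -= direction
        } else {
            value.move(along: direction)           // value += direction
        }
    }
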
--- Sources/TensorFlow/Layer.swift | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Sources/TensorFlow/Layer.swift b/Sources/TensorFlow/Layer.swift index 84e62b482..4a1190933 100644 --- a/Sources/TensorFlow/Layer.swift +++ b/Sources/TensorFlow/Layer.swift @@ -82,13 +82,13 @@ public extension Layer { } } -internal extension Layer where TangentVector: VectorProtocol { +internal extension Differentiable where TangentVector: VectorProtocol { mutating func += (lhs: inout Self, rhs: TangentVector) -> Self { lhs = lhs.moved(along: rhs) } } -internal extension Layer where TangentVector: VectorProtocol, +internal extension Differentiable where TangentVector: VectorProtocol, TangentVector.VectorSpaceScalar: SignedNumeric { mutating func -= (lhs: inout Self, rhs: TangentVector) -> Self { lhs = lhs.moved(along: -rhs) From 92a43a19e6865fe8bcb7dffff5713352e22bdd65 Mon Sep 17 00:00:00 2001 From: Dan Zheng Date: Sat, 15 Jun 2019 00:08:02 -0700 Subject: [PATCH 07/12] Fix `Differentiable` `+=` and `-=` extensions. However, this causes a type checking ambiguity. --- Sources/TensorFlow/Layer.swift | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/Sources/TensorFlow/Layer.swift b/Sources/TensorFlow/Layer.swift index 4a1190933..6c2e99707 100644 --- a/Sources/TensorFlow/Layer.swift +++ b/Sources/TensorFlow/Layer.swift @@ -83,15 +83,15 @@ public extension Layer { } internal extension Differentiable where TangentVector: VectorProtocol { - mutating func += (lhs: inout Self, rhs: TangentVector) -> Self { - lhs = lhs.moved(along: rhs) + static func += (lhs: inout Self, rhs: TangentVector) { + lhs.move(along: rhs) } } internal extension Differentiable where TangentVector: VectorProtocol, - TangentVector.VectorSpaceScalar: SignedNumeric { - mutating func -= (lhs: inout Self, rhs: TangentVector) -> Self { - lhs = lhs.moved(along: -rhs) + TangentVector.VectorSpaceScalar: SignedNumeric { + static func -= (lhs: inout Self, rhs: TangentVector) { + lhs.move(along: -rhs) } } From c1e5fe7e4644307d807c2e14de9b8fceb43478d0 Mon Sep 17 00:00:00 2001 From: Richard Wei Date: Tue, 25 Jun 2019 00:07:55 -0700 Subject: [PATCH 08/12] Update math operators. 
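The new `VectorProtocol` extensions let scalar–vector arithmetic be written with ordinary operators, which is what the simplified optimizer bodies lean on. A small usage sketch (illustration only; `demo` is a stand-in name and assumes the extensions are visible in the same module):

    // Each operator forwards to a VectorProtocol requirement.
    func demo<V: VectorProtocol>(_ v: V) -> V where V.VectorSpaceScalar == Float {
        let scaled = 2 * v          // v.scaled(by: 2)
        let shifted = scaled + 1    // scaled.adding(1)
        return shifted - 0.5        // shifted.subtracting(0.5)
    }
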
--- Sources/TensorFlow/Operators/Math.swift | 46 ++++++++++++++++++++----- Sources/TensorFlow/Optimizer.swift | 7 ++-- 2 files changed, 41 insertions(+), 12 deletions(-) diff --git a/Sources/TensorFlow/Operators/Math.swift b/Sources/TensorFlow/Operators/Math.swift index 4e4e15311..55098085a 100644 --- a/Sources/TensorFlow/Operators/Math.swift +++ b/Sources/TensorFlow/Operators/Math.swift @@ -164,23 +164,51 @@ extension Tensor: ElementaryFunctions where Scalar: TensorFlowFloatingPoint { extension Tensor: VectorProtocol where Scalar: TensorFlowFloatingPoint { public typealias VectorSpaceScalar = Float + // @differentiable(where Scalar: TensorFlowFloatingPoint) public func scaled(by scale: Float) -> Self { Scalar(scale) * self } - @differentiable(where Scalar: TensorFlowFloatingPoint) - public func adding(_ scalar: Scalar) -> Self { - self + scalar + // @differentiable(where Scalar: TensorFlowFloatingPoint) + public func adding(_ scalar: Float) -> Self { + self + Scalar(scalar) } - @differentiable(where Scalar: TensorFlowFloatingPoint) - public func subtracting(_ scalar: Scalar) -> Self { - self - scalar + // @differentiable(where Scalar: TensorFlowFloatingPoint) + public func subtracting(_ scalar: Float) -> Self { + self - Scalar(scalar) } +} - @differentiable(where Scalar: TensorFlowFloatingPoint) - public func scaled(by scalar: Scalar) -> Self { - self * scalar +extension VectorProtocol { + static func + (lhs: VectorSpaceScalar, rhs: Self) -> Self { + rhs.adding(lhs) + } + + static func + (lhs: Self, rhs: VectorSpaceScalar) -> Self { + lhs.adding(rhs) + } + + static func - (lhs: Self, rhs: VectorSpaceScalar) -> Self { + lhs.subtracting(rhs) + } + + static func * (lhs: VectorSpaceScalar, rhs: Self) -> Self { + rhs.scaled(by: lhs) + } + + static func * (lhs: Self, rhs: VectorSpaceScalar) -> Self { + lhs.scaled(by: rhs) + } +} + +extension VectorProtocol where VectorSpaceScalar: SignedNumeric { + static prefix func - (x: Self) -> Self { + .zero - x + } + + static func - (lhs: VectorSpaceScalar, rhs: Self) -> Self { + (-rhs).adding(lhs) } } diff --git a/Sources/TensorFlow/Optimizer.swift b/Sources/TensorFlow/Optimizer.swift index 97088c6c8..687c5cff5 100644 --- a/Sources/TensorFlow/Optimizer.swift +++ b/Sources/TensorFlow/Optimizer.swift @@ -194,11 +194,12 @@ public class SGD: Optimizer where Model.TangentVector: Ve /// A Riemann manifold stochastic gradient descent (SGD) optimizer. public class RiemannSGD: Optimizer where Model.TangentVector: VectorProtocol, - Model.TangentVector.Scalar: SignedNumeric { + Model.TangentVector.VectorSpaceScalar: FloatingPoint { + public typealias Scalar = Model.TangentVector.VectorSpaceScalar /// The learning rate. - public var learningRate: Model.TangentVector.Scalar + public var learningRate: Model.TangentVector.VectorSpaceScalar - public init(learningRate: Model.TangentVector.Scalar) { + public init(learningRate: Model.TangentVector.VectorSpaceScalar) { self.learningRate = learningRate } From 406d860249b166570fea888c7607b7b58ea3a5ad Mon Sep 17 00:00:00 2001 From: Richard Wei Date: Wed, 26 Jun 2019 17:16:29 -0700 Subject: [PATCH 09/12] Revert Adam to using key paths. 
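The vector-math formulation of Adam needs elementwise `sqrt` and division on the tangent vector, which `VectorProtocol` alone does not provide, so this change goes back to walking the model's writable key paths per scalar type. The traversal pattern it restores, shown here as a plain gradient-descent step for brevity (stand-in names; `Float` shown, the `Double` loop is analogous):

    // Visit every Tensor<Float> stored anywhere inside the parameter struct
    // and update it in place from the matching slice of the gradient.
    for kp in parameters.recursivelyAllWritableKeyPaths(to: Tensor<Float>.self) {
        parameters[keyPath: kp] -= learningRate * direction[keyPath: kp]
    }
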
--- Sources/TensorFlow/Optimizer.swift | 49 +++++++++++++++++++++++++----- 1 file changed, 41 insertions(+), 8 deletions(-) diff --git a/Sources/TensorFlow/Optimizer.swift b/Sources/TensorFlow/Optimizer.swift index 687c5cff5..6ec076a9f 100644 --- a/Sources/TensorFlow/Optimizer.swift +++ b/Sources/TensorFlow/Optimizer.swift @@ -38,7 +38,8 @@ fileprivate extension Tensor where Scalar: Numeric { /// /// Reference: ["Adam - A Method for Stochastic Optimization"]( /// https://arxiv.org/abs/1412.6980v8) -public class Adam: Optimizer where Model.TangentVector: VectorProtocol { +public class Adam: Optimizer + where Model.AllDifferentiableVariables == Model.TangentVector { /// The learning rate. public var learningRate: Float /// A coefficient used to calculate the first and second moments of @@ -54,9 +55,9 @@ public class Adam: Optimizer where Model.TangentVector: V /// The current step. public var step: Int = 0 /// The first moments of the weights. - public var firstMoments: Model.TangentVector = .zero + public var firstMoments: Model.AllDifferentiableVariables /// The second moments of the weights. - public var secondMoments: Model.TangentVector = .zero + public var secondMoments: Model.AllDifferentiableVariables public init( for model: __shared Model, @@ -76,20 +77,52 @@ public class Adam: Optimizer where Model.TangentVector: V self.beta2 = beta2 self.epsilon = epsilon self.decay = decay + + // Initialize first & second moments to be zeros of the same shape. + // We can't use `Model.AllDifferentiableVariables.zero` due to the + // interaction between Key Paths and Differentiable Arrays. + firstMoments = model.allDifferentiableVariables + secondMoments = model.allDifferentiableVariables + for kp in firstMoments.recursivelyAllWritableKeyPaths(to: Tensor.self) { + firstMoments[keyPath: kp].resetToZero() + secondMoments[keyPath: kp].resetToZero() + } + for kp in firstMoments.recursivelyAllWritableKeyPaths(to: Tensor.self) { + firstMoments[keyPath: kp].resetToZero() + secondMoments[keyPath: kp].resetToZero() + } } - public func update(_ model: inout Model, along direction: Model.TangentVector) { + public func update(_ model: inout Model.AllDifferentiableVariables, + along direction: Model.AllDifferentiableVariables) { step += 1 - let learningRate = self.learningRate / (1 + decay * Float(step)) + let learningRate = self.learningRate * 1 / (1 + decay * Float(step)) // Note: `stepSize` is split into two lines to avoid the "compiler is unable to type-check // this expression in reasonable time" error. var stepSize = learningRate * sqrt(1 - pow(beta2, Float(step))) stepSize = stepSize / (1 - pow(beta1, Float(step))) // Update Float & Double Tensor variables. 
- firstMoments = firstMoments * beta1 + (1 - beta1) * direction - secondMoments = secondMoments * beta2 + (1 - beta2) * direction * direction - model -= stepSize * firstMoments / (sqrt(secondMoments) + epsilon) + for kp in model.recursivelyAllWritableKeyPaths(to: Tensor.self) { + firstMoments[keyPath: kp] = + firstMoments[keyPath: kp] * beta1 + (1 - beta1) * direction[keyPath: kp] + secondMoments[keyPath: kp] = + secondMoments[keyPath: kp] * beta2 + (1 - beta2) * + direction[keyPath: kp] * direction[keyPath: kp] + model[keyPath: kp] -= + stepSize * firstMoments[keyPath: kp] / (sqrt(secondMoments[keyPath: kp]) + epsilon) + } + for kp in model.recursivelyAllWritableKeyPaths(to: Tensor.self) { + firstMoments[keyPath: kp] = + firstMoments[keyPath: kp] * Double(beta1) + + Double((1 - beta1)) * direction[keyPath: kp] + secondMoments[keyPath: kp] = + secondMoments[keyPath: kp] * Double(beta2) + Double(1 - beta2) * + direction[keyPath: kp] * direction[keyPath: kp] + model[keyPath: kp] -= + Double(stepSize) * firstMoments[keyPath: kp] / + sqrt(secondMoments[keyPath: kp]) + Double(epsilon) + } } } From dd8e75363be20a25066e16036a1d14f52b41bf2d Mon Sep 17 00:00:00 2001 From: Richard Wei Date: Wed, 26 Jun 2019 19:52:19 -0700 Subject: [PATCH 10/12] Revert more to using key paths. --- Sources/TensorFlow/Loss.swift | 3 +- Sources/TensorFlow/Optimizer.swift | 73 ++++++++++++++----- Tests/TensorFlowTests/SequentialTests.swift | 2 +- Tests/TensorFlowTests/TrivialModelTests.swift | 2 +- 4 files changed, 60 insertions(+), 20 deletions(-) diff --git a/Sources/TensorFlow/Loss.swift b/Sources/TensorFlow/Loss.swift index 9252bf74e..76e146a40 100644 --- a/Sources/TensorFlow/Loss.swift +++ b/Sources/TensorFlow/Loss.swift @@ -231,6 +231,7 @@ public func sigmoidCrossEntropy( // This numerical stable implementation is based on tf.nn.sigmoid_cross_entropy_with_logits. let maxLogitsWithZero = max(logits, Tensor(0)) - let loss = maxLogitsWithZero - logits * labels + log(1 + exp(-abs(logits))) + var loss = maxLogitsWithZero - logits * labels + loss = loss + log(1 + exp(-abs(logits))) return loss.mean() } diff --git a/Sources/TensorFlow/Optimizer.swift b/Sources/TensorFlow/Optimizer.swift index 6ec076a9f..badcadd9f 100644 --- a/Sources/TensorFlow/Optimizer.swift +++ b/Sources/TensorFlow/Optimizer.swift @@ -134,7 +134,8 @@ public class Adam: Optimizer /// /// Reference: ["rmsprop: Divide the gradient by a running average of its recent magnitude"]( /// http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf) -public class RMSProp: Optimizer where Model.TangentVector: VectorProtocol { +public class RMSProp: Optimizer + where Model.AllDifferentiableVariables == Model.TangentVector { /// The learning rate. public var learningRate: Float // TODO: Document `rho`. Keras doesn't document `rho`. @@ -146,7 +147,7 @@ public class RMSProp: Optimizer where Model.TangentVector /// The step count. public var step: Float = 0 /// The alpha values for all model differentiable variables. 
- public var alpha: Model.TangentVector = .zero + public var alpha: Model.AllDifferentiableVariables public init( for model: __shared Model, @@ -154,7 +155,7 @@ public class RMSProp: Optimizer where Model.TangentVector rho: Float = 0.9, epsilon: Float = 1e-8, decay: Float = 0 - ) { + ) { precondition(learningRate >= 0, "Learning rate must be non-negative") precondition(rho >= 0, "Rho must be non-negative") precondition(decay >= 0, "Weight decay must be non-negative") @@ -163,14 +164,32 @@ public class RMSProp: Optimizer where Model.TangentVector self.rho = rho self.epsilon = epsilon self.decay = decay + alpha = model.allDifferentiableVariables + for kp in alpha.recursivelyAllWritableKeyPaths(to: Tensor.self) { + alpha[keyPath: kp].resetToZero() + } + for kp in alpha.recursivelyAllWritableKeyPaths(to: Tensor.self) { + alpha[keyPath: kp].resetToZero() + } } - - public func update(_ model: inout Model, along direction: Model.TangentVector) { + public func update(_ model: inout Model.AllDifferentiableVariables, + along direction: Model.TangentVector) { step += 1 - let learningRate = self.learningRate / (1 + decay * Float(step)) - alpha = rho * alpha + (1 - rho) * pow(direction, 2) - model -= learningRate * direction / (sqrt(alpha) + epsilon) + let learningRate = self.learningRate * 1 / (1 + decay * Float(step)) + for kp in model.recursivelyAllWritableKeyPaths(to: Tensor.self) { + alpha[keyPath: kp] = + rho * alpha[keyPath: kp] + (1 - rho) * pow(direction[keyPath: kp], 2) + model[keyPath: kp] -= + learningRate * direction[keyPath: kp] / (sqrt(alpha[keyPath: kp]) + epsilon) + } + for kp in model.recursivelyAllWritableKeyPaths(to: Tensor.self) { + alpha[keyPath: kp] = + Double(rho) * alpha[keyPath: kp] + Double(1 - rho) * pow(direction[keyPath: kp], 2) + model[keyPath: kp] -= + Double(learningRate) * direction[keyPath: kp] / + (sqrt(alpha[keyPath: kp]) + Double(epsilon)) + } } } @@ -178,7 +197,8 @@ public class RMSProp: Optimizer where Model.TangentVector /// /// An optimizer that implements stochastic gradient descent, with support for momentum, learning /// rate decay, and Nesterov momentum. -public class SGD: Optimizer where Model.TangentVector: VectorProtocol { +public class SGD: Optimizer + where Model.TangentVector: VectorProtocol & ElementaryFunctions, Model.TangentVector.VectorSpaceScalar == Float { /// The learning rate. public var learningRate: Float /// The momentum factor. It accelerates stochastic gradient descent in the relevant direction @@ -215,9 +235,9 @@ public class SGD: Optimizer where Model.TangentVector: Ve let learningRate = self.learningRate * 1 / (1 + decay * Float(step)) velocity = momentum * velocity - direction * learningRate if nesterov { - model += momentum * velocity - direction * learningRate + model.move(along: momentum * velocity - direction * learningRate) } else { - model += velocity + model.move(along: velocity) } } } @@ -255,9 +275,10 @@ public class RiemannSGD: Optimizer /// the square root of the sum of all the historical squared values of the gradient. /// /// Reference: ["Adaptive Subgradient Methods for Online Learning and Stochastic Optimization"]( -/// http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf) +/// http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf) /// -public class AdaGrad: Optimizer where Model.TangentVector: VectorProtocol { +public class AdaGrad: Optimizer + where Model.AllDifferentiableVariables == Model.TangentVector { /// The learning rate. public var learningRate: Float /// The smoothing factor (ρ). 
Typical values are `0.5`, `0.9`, and `0.99`, for smoothing over 2, @@ -266,7 +287,7 @@ public class AdaGrad: Optimizer where Model.TangentVector /// A small scalar added to the denominator to improve numerical stability. public var epsilon: Float /// The alpha values for all model differentiable variables. - public var alpha: Model.TangentVector = .zero + public var alpha: Model.AllDifferentiableVariables public init( for model: __shared Model, @@ -280,10 +301,28 @@ public class AdaGrad: Optimizer where Model.TangentVector self.learningRate = learningRate self.rho = rho self.epsilon = epsilon + + alpha = model.allDifferentiableVariables + for kp in alpha.recursivelyAllWritableKeyPaths(to: Tensor.self) { + alpha[keyPath: kp].resetToZero() + } + for kp in alpha.recursivelyAllWritableKeyPaths(to: Tensor.self) { + alpha[keyPath: kp].resetToZero() + } } - public func update(_ model: inout Model, along direction: Model.TangentVector) { - alpha = rho + direction.squared() - model.move(along: -learningRate * direction / (sqrt(alpha + epsilon))) + public func update(_ model: inout Model.AllDifferentiableVariables, + along direction: Model.TangentVector) { + for kp in model.recursivelyAllWritableKeyPaths(to: Tensor.self) { + alpha[keyPath: kp] = rho + direction[keyPath: kp].squared() + model[keyPath: kp] -= + learningRate * direction[keyPath: kp] / (sqrt(alpha[keyPath: kp] + epsilon)) + } + for kp in model.recursivelyAllWritableKeyPaths(to: Tensor.self) { + alpha[keyPath: kp] = Double(rho) + direction[keyPath: kp].squared() + model[keyPath: kp] -= + Double(learningRate) * direction[keyPath: kp] / + (sqrt(alpha[keyPath: kp] + Double(epsilon))) + } } } diff --git a/Tests/TensorFlowTests/SequentialTests.swift b/Tests/TensorFlowTests/SequentialTests.swift index 8d305e9a1..7e27f76ad 100644 --- a/Tests/TensorFlowTests/SequentialTests.swift +++ b/Tests/TensorFlowTests/SequentialTests.swift @@ -38,7 +38,7 @@ final class SequentialTests: XCTestCase { let ŷ = model(x) return meanSquaredError(predicted: ŷ, expected: y) } - optimizer.update(&model.allDifferentiableVariables, along: 𝛁model) + optimizer.update(&model, along: 𝛁model) } XCTAssertEqual(model.inferring(from: [[0, 0], [0, 1], [1, 0], [1, 1]]), [[ 0.4904838], [0.49942452], [0.49740878], [ 0.5106092]]) diff --git a/Tests/TensorFlowTests/TrivialModelTests.swift b/Tests/TensorFlowTests/TrivialModelTests.swift index 3233b1911..83d55ead6 100644 --- a/Tests/TensorFlowTests/TrivialModelTests.swift +++ b/Tests/TensorFlowTests/TrivialModelTests.swift @@ -50,7 +50,7 @@ final class TrivialModelTests: XCTestCase { let ŷ = classifier(x) return meanSquaredError(predicted: ŷ, expected: y) } - optimizer.update(&classifier.allDifferentiableVariables, along: 𝛁model) + optimizer.update(&classifier, along: 𝛁model) } let ŷ = classifier.inferring(from: x) XCTAssertEqual(round(ŷ), y) From a34eb576a6f5b0002d05813d226da532705a4d18 Mon Sep 17 00:00:00 2001 From: Richard Wei Date: Thu, 27 Jun 2019 00:07:42 -0700 Subject: [PATCH 11/12] Make 'update(_:along:)' compatible with 'allDifferentiableVariables'. 
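The forwarding overload keeps both call styles working during the transition away from `AllDifferentiableVariables`: callers may pass the model itself or its `allDifferentiableVariables` view, and both route to the same underlying update. For example (stand-in names):

    // Equivalent ways to apply one optimizer step:
    optimizer.update(&model, along: grad)
    optimizer.update(&model.allDifferentiableVariables, along: grad)
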
--- Sources/TensorFlow/Optimizer.swift | 32 +++++++++++++++++++-- Tests/TensorFlowTests/SequentialTests.swift | 12 ++++++-- 2 files changed, 38 insertions(+), 6 deletions(-) diff --git a/Sources/TensorFlow/Optimizer.swift b/Sources/TensorFlow/Optimizer.swift index badcadd9f..25c8f3016 100644 --- a/Sources/TensorFlow/Optimizer.swift +++ b/Sources/TensorFlow/Optimizer.swift @@ -40,6 +40,7 @@ fileprivate extension Tensor where Scalar: Numeric { /// https://arxiv.org/abs/1412.6980v8) public class Adam: Optimizer where Model.AllDifferentiableVariables == Model.TangentVector { + public typealias Model = Model /// The learning rate. public var learningRate: Float /// A coefficient used to calculate the first and second moments of @@ -124,6 +125,11 @@ public class Adam: Optimizer sqrt(secondMoments[keyPath: kp]) + Double(epsilon) } } + + public func update(_ model: inout Model, + along direction: Model.TangentVector) { + update(&model.allDifferentiableVariables, along: direction) + } } /// RMSProp optimizer. @@ -136,6 +142,7 @@ public class Adam: Optimizer /// http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf) public class RMSProp: Optimizer where Model.AllDifferentiableVariables == Model.TangentVector { + public typealias Model = Model /// The learning rate. public var learningRate: Float // TODO: Document `rho`. Keras doesn't document `rho`. @@ -155,7 +162,7 @@ public class RMSProp: Optimizer rho: Float = 0.9, epsilon: Float = 1e-8, decay: Float = 0 - ) { + ) { precondition(learningRate >= 0, "Learning rate must be non-negative") precondition(rho >= 0, "Rho must be non-negative") precondition(decay >= 0, "Weight decay must be non-negative") @@ -191,6 +198,11 @@ public class RMSProp: Optimizer (sqrt(alpha[keyPath: kp]) + Double(epsilon)) } } + + public func update(_ model: inout Model, + along direction: Model.TangentVector) { + update(&model.allDifferentiableVariables, along: direction) + } } /// Stochastic gradient descent (SGD) optimizer. @@ -198,7 +210,9 @@ public class RMSProp: Optimizer /// An optimizer that implements stochastic gradient descent, with support for momentum, learning /// rate decay, and Nesterov momentum. public class SGD: Optimizer - where Model.TangentVector: VectorProtocol & ElementaryFunctions, Model.TangentVector.VectorSpaceScalar == Float { + where Model.TangentVector: VectorProtocol & ElementaryFunctions, + Model.TangentVector.VectorSpaceScalar == Float { + public typealias Model = Model /// The learning rate. public var learningRate: Float /// The momentum factor. It accelerates stochastic gradient descent in the relevant direction @@ -230,7 +244,8 @@ public class SGD: Optimizer self.nesterov = nesterov } - public func update(_ model: inout Model, along direction: Model.TangentVector) { + public func update(_ model: inout Model.AllDifferentiableVariables, + along direction: Model.TangentVector) { step += 1 let learningRate = self.learningRate * 1 / (1 + decay * Float(step)) velocity = momentum * velocity - direction * learningRate @@ -240,6 +255,11 @@ public class SGD: Optimizer model.move(along: velocity) } } + + public func update(_ model: inout Model, + along direction: Model.TangentVector) { + update(&model.allDifferentiableVariables, along: direction) + } } // MARK: - Manifold optimizers @@ -279,6 +299,7 @@ public class RiemannSGD: Optimizer /// public class AdaGrad: Optimizer where Model.AllDifferentiableVariables == Model.TangentVector { + public typealias Model = Model /// The learning rate. 
public var learningRate: Float /// The smoothing factor (ρ). Typical values are `0.5`, `0.9`, and `0.99`, for smoothing over 2, @@ -325,4 +346,9 @@ public class AdaGrad: Optimizer (sqrt(alpha[keyPath: kp] + Double(epsilon))) } } + + public func update(_ model: inout Model, + along direction: Model.TangentVector) { + update(&model.allDifferentiableVariables, along: direction) + } } diff --git a/Tests/TensorFlowTests/SequentialTests.swift b/Tests/TensorFlowTests/SequentialTests.swift index 7e27f76ad..182f44edb 100644 --- a/Tests/TensorFlowTests/SequentialTests.swift +++ b/Tests/TensorFlowTests/SequentialTests.swift @@ -29,7 +29,10 @@ final class SequentialTests: XCTestCase { } } var model = Model() - let optimizer = SGD(for: model, learningRate: 0.02) + let sgd = SGD(for: model, learningRate: 0.02) + let rmsprop = RMSProp(for: model, learningRate: 0.02) + let adam = Adam(for: model, learningRate: 0.02) + let adagrad = AdaGrad(for: model, learningRate: 0.02) let x: Tensor = [[0, 0], [0, 1], [1, 0], [1, 1]] let y: Tensor = [0, 1, 1, 0] Context.local.learningPhase = .training @@ -38,10 +41,13 @@ final class SequentialTests: XCTestCase { let ŷ = model(x) return meanSquaredError(predicted: ŷ, expected: y) } - optimizer.update(&model, along: 𝛁model) + sgd.update(&model, along: 𝛁model) + rmsprop.update(&model, along: 𝛁model) + adam.update(&model, along: 𝛁model) + adagrad.update(&model, along: 𝛁model) } XCTAssertEqual(model.inferring(from: [[0, 0], [0, 1], [1, 0], [1, 1]]), - [[ 0.4904838], [0.49942452], [0.49740878], [ 0.5106092]]) + [[0.5108122], [0.5108122], [0.5108122], [0.5108122]]) } static var allTests = [ From 7d5bfcf960869b18ab11145b290bdd837078182d Mon Sep 17 00:00:00 2001 From: Richard Wei Date: Thu, 27 Jun 2019 00:11:02 -0700 Subject: [PATCH 12/12] More tests. --- Sources/TensorFlow/Optimizer.swift | 5 ++++- Tests/TensorFlowTests/SequentialTests.swift | 6 +++++- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/Sources/TensorFlow/Optimizer.swift b/Sources/TensorFlow/Optimizer.swift index 25c8f3016..7679e5c66 100644 --- a/Sources/TensorFlow/Optimizer.swift +++ b/Sources/TensorFlow/Optimizer.swift @@ -94,7 +94,7 @@ public class Adam: Optimizer } } - + // TODO: Deprecate this when `Differentiable.AllDifferentiableVariables` is removed. public func update(_ model: inout Model.AllDifferentiableVariables, along direction: Model.AllDifferentiableVariables) { step += 1 @@ -180,6 +180,7 @@ public class RMSProp: Optimizer } } + // TODO: Deprecate this when `Differentiable.AllDifferentiableVariables` is removed. public func update(_ model: inout Model.AllDifferentiableVariables, along direction: Model.TangentVector) { step += 1 @@ -244,6 +245,7 @@ public class SGD: Optimizer self.nesterov = nesterov } + // TODO: Deprecate this when `Differentiable.AllDifferentiableVariables` is removed. public func update(_ model: inout Model.AllDifferentiableVariables, along direction: Model.TangentVector) { step += 1 @@ -332,6 +334,7 @@ public class AdaGrad: Optimizer } } + // TODO: Deprecate this when `Differentiable.AllDifferentiableVariables` is removed. 
public func update(_ model: inout Model.AllDifferentiableVariables, along direction: Model.TangentVector) { for kp in model.recursivelyAllWritableKeyPaths(to: Tensor.self) { diff --git a/Tests/TensorFlowTests/SequentialTests.swift b/Tests/TensorFlowTests/SequentialTests.swift index 182f44edb..b09b5309b 100644 --- a/Tests/TensorFlowTests/SequentialTests.swift +++ b/Tests/TensorFlowTests/SequentialTests.swift @@ -42,12 +42,16 @@ final class SequentialTests: XCTestCase { return meanSquaredError(predicted: ŷ, expected: y) } sgd.update(&model, along: 𝛁model) + sgd.update(&model.allDifferentiableVariables, along: 𝛁model) rmsprop.update(&model, along: 𝛁model) + rmsprop.update(&model.allDifferentiableVariables, along: 𝛁model) adam.update(&model, along: 𝛁model) + adam.update(&model.allDifferentiableVariables, along: 𝛁model) adagrad.update(&model, along: 𝛁model) + adagrad.update(&model.allDifferentiableVariables, along: 𝛁model) } XCTAssertEqual(model.inferring(from: [[0, 0], [0, 1], [1, 0], [1, 1]]), - [[0.5108122], [0.5108122], [0.5108122], [0.5108122]]) + [[0.47705528], [0.47705528], [0.47705528], [0.47705528]]) } static var allTests = [