diff --git a/Sources/TensorFlow/Loss.swift b/Sources/TensorFlow/Loss.swift
index 9252bf74e..76e146a40 100644
--- a/Sources/TensorFlow/Loss.swift
+++ b/Sources/TensorFlow/Loss.swift
@@ -231,6 +231,7 @@ public func sigmoidCrossEntropy<Scalar: TensorFlowFloatingPoint>(
     // This numerically stable implementation is based on tf.nn.sigmoid_cross_entropy_with_logits.
     let maxLogitsWithZero = max(logits, Tensor(0))
-    let loss = maxLogitsWithZero - logits * labels + log(1 + exp(-abs(logits)))
+    var loss = maxLogitsWithZero - logits * labels
+    loss = loss + log(1 + exp(-abs(logits)))
     return loss.mean()
 }
 
diff --git a/Sources/TensorFlow/Operators/Math.swift b/Sources/TensorFlow/Operators/Math.swift
index d6876e739..55098085a 100644
--- a/Sources/TensorFlow/Operators/Math.swift
+++ b/Sources/TensorFlow/Operators/Math.swift
@@ -161,22 +161,54 @@ extension Tensor: ElementaryFunctions where Scalar: TensorFlowFloatingPoint {
 // Vector Space
 //===------------------------------------------------------------------------------------------===//
 
-extension Tensor: VectorProtocol where Scalar: Numeric {
-    public typealias VectorSpaceScalar = Scalar
+extension Tensor: VectorProtocol where Scalar: TensorFlowFloatingPoint {
+    public typealias VectorSpaceScalar = Float
 
-    @differentiable(where Scalar: TensorFlowFloatingPoint)
-    public func adding(_ scalar: Scalar) -> Self {
-        self + scalar
+    // @differentiable(where Scalar: TensorFlowFloatingPoint)
+    public func scaled(by scale: Float) -> Self {
+        Scalar(scale) * self
     }
 
-    @differentiable(where Scalar: TensorFlowFloatingPoint)
-    public func subtracting(_ scalar: Scalar) -> Self {
-        self - scalar
+    // @differentiable(where Scalar: TensorFlowFloatingPoint)
+    public func adding(_ scalar: Float) -> Self {
+        self + Scalar(scalar)
     }
 
-    @differentiable(where Scalar: TensorFlowFloatingPoint)
-    public func scaled(by scalar: Scalar) -> Self {
-        self * scalar
+    // @differentiable(where Scalar: TensorFlowFloatingPoint)
+    public func subtracting(_ scalar: Float) -> Self {
+        self - Scalar(scalar)
+    }
+}
+
+extension VectorProtocol {
+    static func + (lhs: VectorSpaceScalar, rhs: Self) -> Self {
+        rhs.adding(lhs)
+    }
+
+    static func + (lhs: Self, rhs: VectorSpaceScalar) -> Self {
+        lhs.adding(rhs)
+    }
+
+    static func - (lhs: Self, rhs: VectorSpaceScalar) -> Self {
+        lhs.subtracting(rhs)
+    }
+
+    static func * (lhs: VectorSpaceScalar, rhs: Self) -> Self {
+        rhs.scaled(by: lhs)
+    }
+
+    static func * (lhs: Self, rhs: VectorSpaceScalar) -> Self {
+        lhs.scaled(by: rhs)
+    }
+}
+
+extension VectorProtocol where VectorSpaceScalar: SignedNumeric {
+    static prefix func - (x: Self) -> Self {
+        .zero - x
+    }
+
+    static func - (lhs: VectorSpaceScalar, rhs: Self) -> Self {
+        (-rhs).adding(lhs)
     }
 }
 
diff --git a/Sources/TensorFlow/Optimizer.swift b/Sources/TensorFlow/Optimizer.swift
index effeb5c3b..7679e5c66 100644
--- a/Sources/TensorFlow/Optimizer.swift
+++ b/Sources/TensorFlow/Optimizer.swift
@@ -25,8 +25,7 @@ public protocol Optimizer {
     var learningRate: Scalar { get set }
     /// Updates the specified differentiable variables along the specified
    /// direction.
-    mutating func update(_ variables: inout Model.AllDifferentiableVariables,
-                         along direction: Model.TangentVector)
+    mutating func update(_ variables: inout Model, along direction: Model.TangentVector)
 }
 
 fileprivate extension Tensor where Scalar: Numeric {
@@ -35,14 +34,13 @@
     }
 }
 
-// MARK: - Key-path based optimizers
-
 /// Adam optimizer.
 ///
 /// Reference: ["Adam - A Method for Stochastic Optimization"](
 /// https://arxiv.org/abs/1412.6980v8)
 public class Adam<Model: Layer>: Optimizer
     where Model.AllDifferentiableVariables == Model.TangentVector {
+    public typealias Model = Model
     /// The learning rate.
     public var learningRate: Float
     /// A coefficient used to calculate the first and second moments of
@@ -96,7 +94,7 @@ public class Adam<Model: Layer>: Optimizer
         }
     }
 
-
+    // TODO: Deprecate this when `Differentiable.AllDifferentiableVariables` is removed.
     public func update(_ model: inout Model.AllDifferentiableVariables,
                        along direction: Model.AllDifferentiableVariables) {
         step += 1
@@ -127,6 +125,11 @@ public class Adam<Model: Layer>: Optimizer
                 sqrt(secondMoments[keyPath: kp]) + Double(epsilon)
         }
     }
+
+    public func update(_ model: inout Model,
+                       along direction: Model.TangentVector) {
+        update(&model.allDifferentiableVariables, along: direction)
+    }
 }
 
 /// RMSProp optimizer.
@@ -139,6 +142,7 @@
 /// http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)
 public class RMSProp<Model: Layer>: Optimizer
     where Model.AllDifferentiableVariables == Model.TangentVector {
+    public typealias Model = Model
     /// The learning rate.
     public var learningRate: Float
     // TODO: Document `rho`. Keras doesn't document `rho`.
@@ -176,7 +180,7 @@ public class RMSProp<Model: Layer>: Optimizer
         }
     }
 
-
+    // TODO: Deprecate this when `Differentiable.AllDifferentiableVariables` is removed.
     public func update(_ model: inout Model.AllDifferentiableVariables,
                        along direction: Model.TangentVector) {
         step += 1
@@ -195,14 +199,21 @@ public class RMSProp<Model: Layer>: Optimizer
                 (sqrt(alpha[keyPath: kp]) + Double(epsilon))
         }
     }
+
+    public func update(_ model: inout Model,
+                       along direction: Model.TangentVector) {
+        update(&model.allDifferentiableVariables, along: direction)
+    }
 }
 
 /// Stochastic gradient descent (SGD) optimizer.
 ///
 /// An optimizer that implements stochastic gradient descent, with support for momentum, learning
 /// rate decay, and Nesterov momentum.
-public class SGD<Model: Layer>: Optimizer
-    where Model.AllDifferentiableVariables == Model.TangentVector {
+public class SGD<Model: Differentiable>: Optimizer
+    where Model.TangentVector: VectorProtocol & ElementaryFunctions,
+          Model.TangentVector.VectorSpaceScalar == Float {
+    public typealias Model = Model
     /// The learning rate.
     public var learningRate: Float
     /// The momentum factor. It accelerates stochastic gradient descent in the relevant direction
     /// and dampens oscillations.
     public var momentum: Float
@@ -212,8 +223,8 @@
     public var decay: Float
     /// Use Nesterov momentum if true.
     public var nesterov: Bool
-    /// The velocity state of the model
-    public var velocity: Model.AllDifferentiableVariables
+    /// The velocity state of the model.
+    public var velocity: Model.TangentVector = .zero
     /// The set of steps taken.
     public var step: Int = 0
 
@@ -232,53 +243,38 @@
         self.momentum = momentum
         self.decay = decay
         self.nesterov = nesterov
-        velocity = model.allDifferentiableVariables
-        for kp in velocity.recursivelyAllWritableKeyPaths(to: Tensor<Float>.self) {
-            velocity[keyPath: kp].resetToZero()
-        }
-        for kp in velocity.recursivelyAllWritableKeyPaths(to: Tensor<Double>.self) {
-            velocity[keyPath: kp].resetToZero()
-        }
     }
 
+    // TODO: Deprecate this when `Differentiable.AllDifferentiableVariables` is removed.
     public func update(_ model: inout Model.AllDifferentiableVariables,
                        along direction: Model.TangentVector) {
         step += 1
         let learningRate = self.learningRate * 1 / (1 + decay * Float(step))
-        for kp in model.recursivelyAllWritableKeyPaths(to: Tensor<Float>.self) {
-            velocity[keyPath: kp] =
-                momentum * velocity[keyPath: kp] - learningRate * direction[keyPath: kp]
-            if nesterov {
-                model[keyPath: kp] +=
-                    momentum * velocity[keyPath: kp] - learningRate * direction[keyPath: kp]
-            } else {
-                model[keyPath: kp] += velocity[keyPath: kp]
-            }
-        }
-        for kp in model.recursivelyAllWritableKeyPaths(to: Tensor<Double>.self) {
-            velocity[keyPath: kp] =
-                Double(momentum) * velocity[keyPath: kp] -
-                Double(learningRate) * direction[keyPath: kp]
-            if nesterov {
-                model[keyPath: kp] +=
-                    Double(momentum) * velocity[keyPath: kp] - Double(learningRate) *
-                    direction[keyPath: kp]
-            } else {
-                model[keyPath: kp] += velocity[keyPath: kp]
-            }
+        velocity = momentum * velocity - direction * learningRate
+        if nesterov {
+            model.move(along: momentum * velocity - direction * learningRate)
+        } else {
+            model.move(along: velocity)
         }
     }
+
+    public func update(_ model: inout Model,
+                       along direction: Model.TangentVector) {
+        update(&model.allDifferentiableVariables, along: direction)
+    }
 }
 
 // MARK: - Manifold optimizers
 
 /// A Riemann manifold stochastic gradient descent (SGD) optimizer.
-public class RiemannSGD<Model: Differentiable, Scalar: FloatingPoint>: Optimizer
-    where Model.TangentVector: VectorProtocol, Model.TangentVector.VectorSpaceScalar == Scalar {
+public class RiemannSGD<Model: Differentiable>: Optimizer
+    where Model.TangentVector: VectorProtocol,
+          Model.TangentVector.VectorSpaceScalar: FloatingPoint {
+    public typealias Scalar = Model.TangentVector.VectorSpaceScalar
     /// The learning rate.
-    public var learningRate: Scalar
+    public var learningRate: Model.TangentVector.VectorSpaceScalar
 
-    public init(learningRate: Scalar) {
+    public init(learningRate: Model.TangentVector.VectorSpaceScalar) {
         self.learningRate = learningRate
     }
 
@@ -305,6 +301,7 @@ public class RiemannSGD<Model: Differentiable, Scalar: FloatingPoint>: Optimizer
 ///
 public class AdaGrad<Model: Layer>: Optimizer
     where Model.AllDifferentiableVariables == Model.TangentVector {
+    public typealias Model = Model
     /// The learning rate.
     public var learningRate: Float
     /// The smoothing factor (ρ). Typical values are `0.5`, `0.9`, and `0.99`, for smoothing over 2,
@@ -337,6 +334,7 @@ public class AdaGrad<Model: Layer>: Optimizer
         }
     }
 
+    // TODO: Deprecate this when `Differentiable.AllDifferentiableVariables` is removed.
     public func update(_ model: inout Model.AllDifferentiableVariables,
                        along direction: Model.TangentVector) {
         for kp in model.recursivelyAllWritableKeyPaths(to: Tensor<Float>.self) {
@@ -351,4 +349,9 @@ public class AdaGrad<Model: Layer>: Optimizer
                 (sqrt(alpha[keyPath: kp] + Double(epsilon)))
         }
     }
+
+    public func update(_ model: inout Model,
+                       along direction: Model.TangentVector) {
+        update(&model.allDifferentiableVariables, along: direction)
+    }
 }
diff --git a/Tests/TensorFlowTests/SequentialTests.swift b/Tests/TensorFlowTests/SequentialTests.swift
index 8d305e9a1..b09b5309b 100644
--- a/Tests/TensorFlowTests/SequentialTests.swift
+++ b/Tests/TensorFlowTests/SequentialTests.swift
@@ -29,7 +29,10 @@ final class SequentialTests: XCTestCase {
             }
         }
         var model = Model()
-        let optimizer = SGD(for: model, learningRate: 0.02)
+        let sgd = SGD(for: model, learningRate: 0.02)
+        let rmsprop = RMSProp(for: model, learningRate: 0.02)
+        let adam = Adam(for: model, learningRate: 0.02)
+        let adagrad = AdaGrad(for: model, learningRate: 0.02)
         let x: Tensor<Float> = [[0, 0], [0, 1], [1, 0], [1, 1]]
         let y: Tensor<Float> = [0, 1, 1, 0]
         Context.local.learningPhase = .training
@@ -38,10 +41,17 @@
                 let ŷ = model(x)
                 return meanSquaredError(predicted: ŷ, expected: y)
             }
-            optimizer.update(&model.allDifferentiableVariables, along: 𝛁model)
+            sgd.update(&model, along: 𝛁model)
+            sgd.update(&model.allDifferentiableVariables, along: 𝛁model)
+            rmsprop.update(&model, along: 𝛁model)
+            rmsprop.update(&model.allDifferentiableVariables, along: 𝛁model)
+            adam.update(&model, along: 𝛁model)
+            adam.update(&model.allDifferentiableVariables, along: 𝛁model)
+            adagrad.update(&model, along: 𝛁model)
+            adagrad.update(&model.allDifferentiableVariables, along: 𝛁model)
         }
         XCTAssertEqual(model.inferring(from: [[0, 0], [0, 1], [1, 0], [1, 1]]),
-                       [[ 0.4904838], [0.49942452], [0.49740878], [ 0.5106092]])
+                       [[0.47705528], [0.47705528], [0.47705528], [0.47705528]])
     }
 
     static var allTests = [
diff --git a/Tests/TensorFlowTests/TrivialModelTests.swift b/Tests/TensorFlowTests/TrivialModelTests.swift
index 3233b1911..83d55ead6 100644
--- a/Tests/TensorFlowTests/TrivialModelTests.swift
+++ b/Tests/TensorFlowTests/TrivialModelTests.swift
@@ -50,7 +50,7 @@ final class TrivialModelTests: XCTestCase {
                 let ŷ = classifier(x)
                 return meanSquaredError(predicted: ŷ, expected: y)
             }
-            optimizer.update(&classifier.allDifferentiableVariables, along: 𝛁model)
+            optimizer.update(&classifier, along: 𝛁model)
         }
         let ŷ = classifier.inferring(from: x)
         XCTAssertEqual(round(ŷ), y)
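
---

Notes for reviewers (supplementary, not part of the patch).

The `Loss.swift` hunk only splits the fused loss expression into two statements; the numerically stable formulation itself is unchanged. The identity being computed is `max(x, 0) - x*z + log(1 + exp(-|x|))`, which equals the binary cross-entropy `-z*log(σ(x)) - (1-z)*log(1-σ(x))` but never exponentiates a large positive argument. A minimal standalone sketch on plain `Double` (the function names here are illustrative, not from the patch):

```swift
import Foundation

// Stable binary cross-entropy from a logit x and a label z in [0, 1]:
//   max(x, 0) - x*z + log(1 + exp(-|x|))
func stableSigmoidCrossEntropy(logit x: Double, label z: Double) -> Double {
    max(x, 0) - x * z + log(1 + exp(-abs(x)))
}

// The naive form computes the sigmoid first and breaks down for large |x|.
func naiveSigmoidCrossEntropy(logit x: Double, label z: Double) -> Double {
    let p = 1 / (1 + exp(-x))
    return -(z * log(p) + (1 - z) * log(1 - p))
}

print(stableSigmoidCrossEntropy(logit: 2, label: 1))    // ≈ 0.1269, agrees with naive
print(naiveSigmoidCrossEntropy(logit: 2, label: 1))     // ≈ 0.1269
print(stableSigmoidCrossEntropy(logit: 1000, label: 0)) // 1000.0, still finite
print(naiveSigmoidCrossEntropy(logit: 1000, label: 0))  // inf: log(1 - p) blows up
```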
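The new `VectorProtocol` operator extensions in `Math.swift` are what let the rewritten `SGD.update` operate on whole tangent vectors instead of iterating key paths per tensor type. A reduced sketch of the same pattern on a toy vector type (`Vec2` is invented for illustration; in the patch the `*` and `-` operators are derived from `scaled(by:)`/`adding(_:)` on any `VectorProtocol` conformer):

```swift
// A toy 2-vector standing in for Model.TangentVector.
struct Vec2 {
    var x: Float
    var y: Float
    static var zero: Vec2 { Vec2(x: 0, y: 0) }
}

// The same shape of API the patch adds to VectorProtocol.
extension Vec2 {
    func adding(_ s: Float) -> Vec2 { Vec2(x: x + s, y: y + s) }
    func subtracting(_ s: Float) -> Vec2 { Vec2(x: x - s, y: y - s) }
    func scaled(by s: Float) -> Vec2 { Vec2(x: x * s, y: y * s) }

    static func * (lhs: Float, rhs: Vec2) -> Vec2 { rhs.scaled(by: lhs) }
    static func * (lhs: Vec2, rhs: Float) -> Vec2 { lhs.scaled(by: rhs) }
    static func - (lhs: Vec2, rhs: Vec2) -> Vec2 { Vec2(x: lhs.x - rhs.x, y: lhs.y - rhs.y) }
}

// The SGD velocity update, written exactly as in the patched `update`:
//   v ← momentum * v - direction * learningRate
var velocity = Vec2.zero
let direction = Vec2(x: 1, y: -2)  // a gradient
let (momentum, learningRate): (Float, Float) = (0.9, 0.1)
velocity = momentum * velocity - direction * learningRate
print(velocity)  // Vec2(x: -0.1, y: 0.2)
```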
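End to end, the user-visible change is that optimizers now mutate the model value itself via `move(along:)` rather than its `allDifferentiableVariables` projection, which is what the updated tests exercise. A usage sketch mirroring `SequentialTests` (the model layout is abridged from that test and assumes the swift-apis API as patched here):

```swift
import TensorFlow

struct XORModel: Layer {
    var dense1 = Dense<Float>(inputSize: 2, outputSize: 4, activation: relu)
    var dense2 = Dense<Float>(inputSize: 4, outputSize: 1, activation: sigmoid)

    @differentiable
    func callAsFunction(_ input: Tensor<Float>) -> Tensor<Float> {
        dense2(dense1(input))
    }
}

var model = XORModel()
let optimizer = SGD(for: model, learningRate: 0.02)
let x: Tensor<Float> = [[0, 0], [0, 1], [1, 0], [1, 1]]
let y: Tensor<Float> = [0, 1, 1, 0]

let 𝛁model = gradient(at: model) { model -> Tensor<Float> in
    meanSquaredError(predicted: model(x), expected: y)
}

// Old call site (still compiles, but slated for deprecation per the TODOs):
// optimizer.update(&model.allDifferentiableVariables, along: 𝛁model)
// New call site: the optimizer updates the model directly.
optimizer.update(&model, along: 𝛁model)
```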