diff --git a/Sources/DeepLearning/Layer.swift b/Sources/DeepLearning/Layer.swift
index ee68ee9ca..4637ba715 100644
--- a/Sources/DeepLearning/Layer.swift
+++ b/Sources/DeepLearning/Layer.swift
@@ -1281,7 +1281,7 @@ public extension RNNCell {
 }
 
 /// A Simple RNN Cell.
-public struct SimpleRNNCell<Scalar: TensorFlowFloatingPoint>: RNNCell {
+public struct SimpleRNNCell<Scalar: TensorFlowFloatingPoint>: RNNCell, VectorNumeric {
     public var weight: Tensor<Scalar>
     public var bias: Tensor<Scalar>
 
@@ -1304,9 +1304,13 @@ public struct SimpleRNNCell<Scalar: TensorFlowFloatingPoint>: RNNCell {
     /// - Parameters:
     ///   - inputSize: The number of features in 2-D input tensors.
     ///   - hiddenSize: The number of features in 2-D hidden states.
-    public init(inputSize: Int, hiddenSize: Int) {
+    ///   - seed: The random seed for initialization. The default value is random.
+    public init(inputSize: Int, hiddenSize: Int,
+                seed: (Int64, Int64) = (Int64.random(in: Int64.min..<Int64.max),
+                                        Int64.random(in: Int64.min..<Int64.max))) {
         let concatenatedInputSize = inputSize + hiddenSize
-        self.weight = Tensor(glorotUniform: [concatenatedInputSize, hiddenSize])
+        self.weight = Tensor(glorotUniform: [concatenatedInputSize, hiddenSize],
+                             seed: seed)
         self.bias = Tensor(zeros: [hiddenSize])
     }
 
@@ -1330,7 +1334,7 @@ public struct SimpleRNNCell<Scalar: TensorFlowFloatingPoint>: RNNCell {
 }
 
 /// An LSTM Cell.
-public struct LSTMCell<Scalar: TensorFlowFloatingPoint>: RNNCell {
+public struct LSTMCell<Scalar: TensorFlowFloatingPoint>: RNNCell, VectorNumeric {
     public var inputWeight, updateWeight, forgetWeight, outputWeight: Tensor<Scalar>
     public var inputBias, updateBias, forgetBias, outputBias: Tensor<Scalar>
 
@@ -1348,17 +1352,19 @@ public struct LSTMCell<Scalar: TensorFlowFloatingPoint>: RNNCell {
     /// - Parameters:
     ///   - inputSize: The number of features in 2-D input tensors.
     ///   - hiddenSize: The number of features in 2-D hidden states.
-    public init(inputSize: Int, hiddenSize: Int) {
+    public init(inputSize: Int, hiddenSize: Int,
+                seed: (Int64, Int64) = (Int64.random(in: Int64.min..<Int64.max),
+                                        Int64.random(in: Int64.min..<Int64.max))) {
         let concatenatedInputSize = inputSize + hiddenSize
         let gateWeightShape = TensorShape([concatenatedInputSize, hiddenSize])
         let gateBiasShape = TensorShape([hiddenSize])
-        self.inputWeight = Tensor(glorotUniform: gateWeightShape)
+        self.inputWeight = Tensor(glorotUniform: gateWeightShape, seed: seed)
         self.inputBias = Tensor(zeros: gateBiasShape)
-        self.updateWeight = Tensor(glorotUniform: gateWeightShape)
+        self.updateWeight = Tensor(glorotUniform: gateWeightShape, seed: seed)
         self.updateBias = Tensor(zeros: gateBiasShape)
-        self.forgetWeight = Tensor(glorotUniform: gateWeightShape)
+        self.forgetWeight = Tensor(glorotUniform: gateWeightShape, seed: seed)
         self.forgetBias = Tensor(ones: gateBiasShape)
-        self.outputWeight = Tensor(glorotUniform: gateWeightShape)
+        self.outputWeight = Tensor(glorotUniform: gateWeightShape, seed: seed)
         self.outputBias = Tensor(zeros: gateBiasShape)
     }
 
@@ -1381,3 +1387,91 @@ public struct LSTMCell<Scalar: TensorFlowFloatingPoint>: RNNCell {
         return Output(output: newState, state: newState)
     }
 }
+
+public struct RNN<Cell: RNNCell>: Layer {
+    public typealias Input = [Cell.TimeStepInput]
+    public typealias Output = [Cell.TimeStepOutput]
+
+    public var cell: Cell
+
+    public init(_ cell: @autoclosure () -> Cell) {
+        self.cell = cell()
+    }
+
+    @differentiable(wrt: (self, input), vjp: _vjpCall(_:initialState:))
+    public func call(_ input: [Cell.TimeStepInput],
+                     initialState: Cell.State) -> [Cell.TimeStepOutput] {
+        var currentHiddenState = initialState
+        var timeStepOutputs: [Cell.TimeStepOutput] = []
+        for timestep in input {
+            let output = cell(input: timestep, state: currentHiddenState)
+            currentHiddenState = output.state
+            timeStepOutputs.append(output.output)
+        }
+        return timeStepOutputs
+    }
+
+    @usableFromInline
+    internal func _vjpCall(
+        _ inputs: [Cell.TimeStepInput], initialState: Cell.State
+    ) -> ([Cell.TimeStepOutput],
+          (Array<Cell.TimeStepOutput>.CotangentVector)
+              -> (CotangentVector, Array<Cell.TimeStepInput>.CotangentVector)) {
+        let timeStepCount = inputs.count
+        var currentHiddenState = cell.zeroState
+        var timeStepOutputs: [Cell.TimeStepOutput] = []
+        timeStepOutputs.reserveCapacity(timeStepCount)
+        var backpropagators: [Cell.Backpropagator] = []
+        backpropagators.reserveCapacity(timeStepCount)
+        for timestep in inputs {
+            let (output, backpropagator) =
+                cell.appliedForBackpropagation(to: .init(input: timestep,
+                                                         state: currentHiddenState))
+            currentHiddenState = output.state
+            timeStepOutputs.append(output.output)
+            backpropagators.append(backpropagator)
+        }
+        return (timeStepOutputs, { 𝛁outputs in
+            precondition(𝛁outputs.base.count == timeStepCount,
+                         "The number of output gradients must equal the number of time steps")
+            var 𝛁cell = Cell.CotangentVector.zero
+            var 𝛁state = Cell.State.CotangentVector.zero
+            var reversed𝛁inputs: [Cell.TimeStepInput.CotangentVector] = []
+            reversed𝛁inputs.reserveCapacity(timeStepCount)
+            for (𝛁output, backpropagator) in zip(𝛁outputs.base, backpropagators).reversed() {
+                let (new𝛁cell, 𝛁input) = backpropagator(.init(output: 𝛁output, state: 𝛁state))
+                𝛁cell = new𝛁cell
+                𝛁state = 𝛁input.state
+                reversed𝛁inputs.append(𝛁input.input)
+            }
+            return (.init(cell: 𝛁cell), .init(Array(reversed𝛁inputs.reversed())))
+        })
+    }
+
+    @differentiable(wrt: (self, inputs))
+    public func call(_ inputs: [Cell.TimeStepInput]) -> [Cell.TimeStepOutput] {
+        return self(inputs, initialState: cell.zeroState.withoutDerivative())
+    }
+
+    /* TODO: Uncomment once control flow and differentiation through force unwrapping is supported.
+    @differentiable(wrt: (self, inputs))
+    public func lastOutput(from inputs: [Cell.TimeStepInput],
+                           initialState: Cell.State) -> Cell.TimeStepOutput {
+        precondition(!inputs.isEmpty, "inputs cannot be empty")
+        return self(inputs, initialState: initialState).last!
+    }
+
+    @differentiable(wrt: (self, inputs))
+    public func lastOutput(from inputs: [Cell.TimeStepInput]) -> Cell.TimeStepOutput {
+        precondition(!inputs.isEmpty, "inputs cannot be empty")
+        return self(inputs, initialState: cell.zeroState).last!
+    }
+    */
+}
+
+extension RNN: Equatable where Cell: Equatable {}
+extension RNN: AdditiveArithmetic where Cell: AdditiveArithmetic {}
+extension RNN: VectorNumeric where Cell: VectorNumeric {}
+
+public typealias SimpleRNN<Scalar: TensorFlowFloatingPoint> = RNN<SimpleRNNCell<Scalar>>
+public typealias LSTM<Scalar: TensorFlowFloatingPoint> = RNN<LSTMCell<Scalar>>
diff --git a/Tests/DeepLearningTests/LayerTests.swift b/Tests/DeepLearningTests/LayerTests.swift
index 5bd8ba179..a347c4d7e 100644
--- a/Tests/DeepLearningTests/LayerTests.swift
+++ b/Tests/DeepLearningTests/LayerTests.swift
@@ -95,6 +95,31 @@ final class LayerTests: XCTestCase {
         XCTAssertEqual(output, expected)
     }
 
+    func testRNN() {
+        let x = Tensor<Float>(rangeFrom: 0.0, to: 0.4, stride: 0.1).rankLifted()
+        let inputs: [Tensor<Float>] = Array(repeating: x, count: 4)
+        let rnn = RNN(SimpleRNNCell<Float>(inputSize: 4, hiddenSize: 4,
+                                           seed: (0xFeedBeef, 0xDeadBeef)))
+        let (outputs, pullback) = rnn.valueWithPullback(at: inputs) { rnn, inputs in
+            return rnn(inputs)
+        }
+        XCTAssertEqual(outputs, [[[-0.0026294366, -0.0058668107,   0.04495003,  0.20311214]],
+                                 [[   0.06788494,   0.050665878,   0.02415526,  0.09249911]],
+                                 [[   0.06621192,   0.009049267,  0.065047316,  0.11534518]],
+                                 [[   0.05612204, 0.00022032857,   0.05407162,  0.09784105]]])
+        let (𝛁rnn, 𝛁inputs) = pullback(.init(inputs))
+        XCTAssertEqual(𝛁rnn.cell.weight,
+                       [[          0.0,          0.0,        0.0,         0.0],
+                        [-0.0051278225, 0.0013102926, 0.00740262, 0.018119661],
+                        [ -0.010255645, 0.0026205853, 0.01480524, 0.036239322],
+                        [ -0.015383467,  0.003930878, 0.02220786, 0.054358985],
+                        [          0.0,          0.0,        0.0,         0.0],
+                        [          0.0,          0.0,        0.0,         0.0],
+                        [          0.0,          0.0,        0.0,         0.0],
+                        [          0.0,          0.0,        0.0,         0.0]])
+        XCTAssertEqual(𝛁rnn.cell.bias, [-0.051278222, 0.013102926, 0.0740262, 0.18119662])
+    }
+
     static var allTests = [
         ("testConv1D", testConv1D),
         ("testMaxPool1D", testMaxPool1D),
@@ -104,6 +129,7 @@ final class LayerTests: XCTestCase {
         ("testGlobalAvgPool3D", testGlobalAvgPool3D),
         ("testReshape", testReshape),
         ("testFlatten", testFlatten),
-        ("testSimpleRNNCell", testSimpleRNNCell)
+        ("testSimpleRNNCell", testSimpleRNNCell),
+        ("testRNN", testRNN)
     ]
 }
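
For reviewers, a short usage sketch of the new API, mirroring `testRNN` above. This is not part of the diff; it assumes the `DeepLearning` module as of this commit, where `rnn(inputs)` dispatches to the `call` methods added here.

import DeepLearning

// Four identical time steps, each of shape [1, 4].
let x = Tensor<Float>(rangeFrom: 0.0, to: 0.4, stride: 0.1).rankLifted()
let inputs: [Tensor<Float>] = Array(repeating: x, count: 4)

// A single-layer simple RNN; the seed makes the Glorot-uniform
// parameter initialization reproducible.
let rnn = RNN(SimpleRNNCell<Float>(inputSize: 4, hiddenSize: 4,
                                   seed: (0xFeedBeef, 0xDeadBeef)))

// Forward pass over all time steps, starting from the cell's zero state.
let outputs = rnn(inputs)

// Forward plus reverse pass: the pullback maps per-time-step output
// gradients to gradients with respect to the layer and its inputs.
let (_, pullback) = rnn.valueWithPullback(at: inputs) { rnn, inputs in rnn(inputs) }
// As in the test, the inputs themselves serve as stand-in output gradients.
let (𝛁rnn, _) = pullback(.init(inputs))
print(𝛁rnn.cell.weight, 𝛁rnn.cell.bias)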
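One property worth noting, sketched below under the same assumptions: the new `seed:` parameter is forwarded to the Glorot-uniform initializer, so two cells constructed with the same seed start from identical parameters. This is what keeps the expected values in `testRNN` stable across runs.

// Same seed, same initial parameters. `==` here is the whole-tensor
// Equatable comparison also relied on by XCTAssertEqual above.
let a = SimpleRNNCell<Float>(inputSize: 4, hiddenSize: 4, seed: (42, 42))
let b = SimpleRNNCell<Float>(inputSize: 4, hiddenSize: 4, seed: (42, 42))
print(a.weight == b.weight)  // true
// Omitting `seed:` falls back to a random seed, so parameters differ per run.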