diff --git a/Sources/DeepLearning/Layer.swift b/Sources/DeepLearning/Layer.swift
index 8348d5889..7f09c27c7 100644
--- a/Sources/DeepLearning/Layer.swift
+++ b/Sources/DeepLearning/Layer.swift
@@ -1276,7 +1276,6 @@ public protocol RNNCell: Layer where Input == RNNCellInput<TimeStepInput, State>
     /// The state that may be preserved across time steps.
     associatedtype State: Differentiable
     /// The zero state.
-    @differentiable
     var zeroState: State { get }
 }
 
@@ -1293,3 +1292,115 @@ public extension RNNCell {
         return applied(to: RNNCellInput(input: input, state: state))
     }
 }
+
+/// A simple RNN cell.
+public struct SimpleRNNCell<Scalar: TensorFlowFloatingPoint>: RNNCell {
+    public var weight: Tensor<Scalar>
+    public var bias: Tensor<Scalar>
+
+    @noDerivative public var stateShape: TensorShape {
+        return TensorShape([1, weight.shape[1]])
+    }
+
+    public var zeroState: Tensor<Scalar> {
+        return Tensor(zeros: stateShape)
+    }
+
+    public typealias State = Tensor<Scalar>
+    public typealias TimeStepInput = Tensor<Scalar>
+    public typealias TimeStepOutput = State
+    public typealias Input = RNNCellInput<TimeStepInput, State>
+    public typealias Output = RNNCellOutput<TimeStepOutput, State>
+
+    /// Creates a `SimpleRNNCell` with the specified input size and hidden state size.
+    ///
+    /// - Parameters:
+    ///   - inputSize: The number of features in 2-D input tensors.
+    ///   - hiddenSize: The number of features in 2-D hidden states.
+    public init(inputSize: Int, hiddenSize: Int) {
+        let concatenatedInputSize = inputSize + hiddenSize
+        self.weight = Tensor(glorotUniform: [concatenatedInputSize, hiddenSize])
+        self.bias = Tensor(zeros: [hiddenSize])
+    }
+
+    /// Returns the output obtained from applying the layer to the given input.
+    ///
+    /// - Parameter input: The input to the layer.
+    /// - Returns: The hidden state.
+    @differentiable
+    public func applied(to input: Input) -> Output {
+        let concatenatedInput = input.input.concatenated(with: input.state, alongAxis: 1)
+        let newState = matmul(concatenatedInput, weight) + bias
+        return Output(output: newState, state: newState)
+    }
+}
+
+/// An LSTM cell.
+public struct LSTMCell<Scalar: TensorFlowFloatingPoint>: RNNCell {
+    public var inputWeight, updateWeight, forgetWeight, outputWeight: Tensor<Scalar>
+    public var inputBias, updateBias, forgetBias, outputBias: Tensor<Scalar>
+
+    @noDerivative public var stateShape: TensorShape {
+        return TensorShape([1, inputWeight.shape[1]])
+    }
+
+    public var zeroState: State {
+        return State(cell: Tensor(zeros: stateShape), hidden: Tensor(zeros: stateShape))
+    }
+
+    public typealias TimeStepInput = Tensor<Scalar>
+    public typealias TimeStepOutput = State
+    public typealias Input = RNNCellInput<TimeStepInput, State>
+    public typealias Output = RNNCellOutput<TimeStepOutput, State>
+
+    /// Creates an `LSTMCell` with the specified input size and hidden state size.
+    ///
+    /// - Parameters:
+    ///   - inputSize: The number of features in 2-D input tensors.
+    ///   - hiddenSize: The number of features in 2-D hidden states.
+    public init(inputSize: Int, hiddenSize: Int) {
+        let concatenatedInputSize = inputSize + hiddenSize
+        let gateWeightShape = TensorShape([concatenatedInputSize, hiddenSize])
+        let gateBiasShape = TensorShape([hiddenSize])
+        self.inputWeight = Tensor(glorotUniform: gateWeightShape)
+        self.inputBias = Tensor(zeros: gateBiasShape)
+        self.updateWeight = Tensor(glorotUniform: gateWeightShape)
+        self.updateBias = Tensor(zeros: gateBiasShape)
+        self.forgetWeight = Tensor(glorotUniform: gateWeightShape)
+        self.forgetBias = Tensor(ones: gateBiasShape)
+        self.outputWeight = Tensor(glorotUniform: gateWeightShape)
+        self.outputBias = Tensor(zeros: gateBiasShape)
+    }
+
+    public struct State: Differentiable {
+        public var cell: Tensor<Scalar>
+        public var hidden: Tensor<Scalar>
+
+        @differentiable
+        public init(cell: Tensor<Scalar>, hidden: Tensor<Scalar>) {
+            self.cell = cell
+            self.hidden = hidden
+        }
+    }
+
+    /// Returns the output obtained from applying the layer to the given input.
+    ///
+    /// - Parameter input: The input to the layer.
+    /// - Returns: The hidden state.
+    @differentiable
+    public func applied(to input: Input) -> Output {
+        let gateInput = input.input.concatenated(with: input.state.hidden, alongAxis: 1)
+
+        let inputGate = sigmoid(matmul(gateInput, inputWeight) + inputBias)
+        let updateGate = tanh(matmul(gateInput, updateWeight) + updateBias)
+        let forgetGate = sigmoid(matmul(gateInput, forgetWeight) + forgetBias)
+        let outputGate = sigmoid(matmul(gateInput, outputWeight) + outputBias)
+
+        let newCellState = input.state.cell * forgetGate + inputGate * updateGate
+        let newHiddenState = tanh(newCellState) * outputGate
+
+        let newState = State(cell: newCellState, hidden: newHiddenState)
+
+        return Output(output: newState, state: newState)
+    }
+}
diff --git a/Tests/DeepLearningTests/LayerTests.swift b/Tests/DeepLearningTests/LayerTests.swift
index 1de95661a..a36790db4 100644
--- a/Tests/DeepLearningTests/LayerTests.swift
+++ b/Tests/DeepLearningTests/LayerTests.swift
@@ -82,6 +82,19 @@ final class LayerTests: XCTestCase {
         XCTAssertEqual(output.shape, expected)
     }
 
+    func testSimpleRNNCell() {
+        let weight = Tensor<Float>(ones: [7, 5]) * Tensor<Float>([0.3333, 1, 0.3333, 1, 0.3333])
+        let bias = Tensor<Float>(ones: [5])
+        var cell = SimpleRNNCell<Float>(inputSize: 2, hiddenSize: 5)
+        cell.weight = weight
+        cell.bias = bias
+        let state = Tensor<Float>(ones: [1, 5]) * Tensor<Float>([1, 0.2, 0.5, 2, 0.6])
+        let input = Tensor<Float>(ones: [1, 2]) * Tensor<Float>([0.3, 0.7])
+        let output = cell.applied(to: input, state: state).state
+        let expected = Tensor<Float>([[2.76649, 6.2999997, 2.76649, 6.2999997, 2.76649]])
+        XCTAssertEqual(output, expected)
+    }
+
     static var allTests = [
         ("testConv1D", testConv1D),
         ("testMaxPool1D", testMaxPool1D),
@@ -90,6 +103,7 @@ final class LayerTests: XCTestCase {
         ("testGlobalAvgPool2D", testGlobalAvgPool2D),
         ("testGlobalAvgPool3D", testGlobalAvgPool3D),
         ("testReshape", testReshape),
-        ("testFlatten", testFlatten)
+        ("testFlatten", testFlatten),
+        ("testSimpleRNNCell", testSimpleRNNCell)
     ]
 }
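Reviewer note: below is a minimal usage sketch, not part of the patch, showing how the new `SimpleRNNCell` is unrolled over a sequence by threading the state from each application into the next, seeded with `zeroState`. The `applied(to:state:)` convenience comes from the `RNNCell` extension above; the sizes, data, and loop are assumptions for illustration only.

```swift
import TensorFlow
import DeepLearning

// Illustrative sizes and data only.
let cell = SimpleRNNCell<Float>(inputSize: 4, hiddenSize: 8)
var state = cell.zeroState  // [1, 8] zeros, from stateShape.
let timeSteps = Array(repeating: Tensor<Float>(ones: [1, 4]), count: 10)
for timeStep in timeSteps {
    // Each application returns both the output and the state to carry forward;
    // for SimpleRNNCell they are the same tensor.
    state = cell.applied(to: timeStep, state: state).state
}
print(state.shape)  // Expected: [1, 8]
```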
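The same pattern drives the new `LSTMCell`, except that its state is the nested `State` struct carrying both the cell and hidden tensors. Again, a hedged sketch with assumed shapes:

```swift
import TensorFlow
import DeepLearning

let lstm = LSTMCell<Float>(inputSize: 4, hiddenSize: 8)
var state = lstm.zeroState  // cell and hidden are both [1, 8] zeros.
for _ in 0..<10 {
    let timeStep = Tensor<Float>(ones: [1, 4])  // Placeholder input.
    state = lstm.applied(to: timeStep, state: state).state
}
print(state.hidden.shape)  // Expected: [1, 8]
```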
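For reference, the hand-computed expectation in `testSimpleRNNCell` follows directly from the cell's affine update: the concatenated `[input, state]` row is `[0.3, 0.7, 1, 0.2, 0.5, 2, 0.6]`, which sums to 5.3, and since every column of `weight` is constant, output column `j` is simply `5.3 * w[j] + 1`. Columns with weight 0.3333 give 2.76649, and columns with weight 1 give 6.3 (6.2999997 under `Float` rounding).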