diff --git a/Sources/DeepLearning/Layer.swift b/Sources/DeepLearning/Layer.swift index 7c4cf1d23..28116fe5f 100644 --- a/Sources/DeepLearning/Layer.swift +++ b/Sources/DeepLearning/Layer.swift @@ -1314,3 +1314,44 @@ public struct Reshape: Layer { return input.reshaped(toShape: shape) } } + +/// An input to a recurrent neural network. +public struct RNNInput: Differentiable { + /// The input at the current time step. + public var timeStepInput: TimeStepInput + /// The previous state. + public var previousState: State + + @differentiable + public init(timeStepInput: TimeStepInput, previousState: State) { + self.timeStepInput = timeStepInput + self.previousState = previousState + } +} + +/// A recurrent neural network cell. +public protocol RNNCell: Layer where Input == RNNInput { + /// The input at a time step. + associatedtype TimeStepInput: Differentiable + /// The state that may be preserved across time steps. + typealias State = Output + /// The zero state. + var zeroState: State { get } +} + +public extension RNNCell { + /// Returns the new state obtained from applying the RNN cell to the input at the current time + /// step and the previous state. + /// + /// - Parameters: + /// - timeStepInput: The input at the current time step. + /// - previousState: The previous state of the RNN cell. + /// - context: The contextual information for the layer application, e.g. the current learning + /// phase. + /// - Returns: The output. + @differentiable + func applied(to timeStepInput: TimeStepInput, previous: State, in context: Context) -> State { + return applied(to: Input(timeStepInput: timeStepInput, previousState: previous), + in: context) + } +}