From dc873691954e7050f1a13750aac54021f4f5bde9 Mon Sep 17 00:00:00 2001 From: Richard Wei Date: Sun, 7 Apr 2019 06:01:52 +0100 Subject: [PATCH 1/4] Add `RNNCell` protocol. --- Sources/DeepLearning/Layer.swift | 41 ++++++++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/Sources/DeepLearning/Layer.swift b/Sources/DeepLearning/Layer.swift index 7c4cf1d23..b6559049a 100644 --- a/Sources/DeepLearning/Layer.swift +++ b/Sources/DeepLearning/Layer.swift @@ -1314,3 +1314,44 @@ public struct Reshape: Layer { return input.reshaped(toShape: shape) } } + +/// An input to a recurrent neural network. +public struct RNNInput: Differentiable { + /// The input at the current time step. + public var timeStepInput: TimeStepInput + /// The previous state. + public var previousState: State + + @differentiable + public init(timeStepInput: StepInput, previousState: State) { + self.timeStepInput = timeStepInput + self.previousState = previousState + } +} + +/// A recurrent neural network cell. +public protocol RNNCell: Layer where Input == RNNInput { + /// The input at a time step. + associatedtype TimeStepInput: Differentiable + /// The state that may be preserved across time steps. + typealias State = Output + /// The zero state. + var zeroState: State { get } +} + +public extension RNNCell { + /// Returns the new state obtained from applying the RNN cell to the input at the current time + /// step and the previous state. + /// + /// - Parameters: + /// - timeStepInput: The input at the current time step. + /// - previousState: The previous state of the RNN cell. + /// - context: The contextual information for the layer application, e.g. the current learning + /// phase. + /// - Returns: The output. + @differentiable + func applied(to timeStepInput: StepInput, previous: State, in context: Context) -> State { + return applied(to: Input(timeStepInput: timeStepInput, previousState: previousState), + in: context) + } +} From 3e5703ff14c271660d88028adc714d5ff71ca5c3 Mon Sep 17 00:00:00 2001 From: Richard Wei Date: Sun, 7 Apr 2019 06:12:22 +0100 Subject: [PATCH 2/4] Fix build. --- Sources/DeepLearning/Layer.swift | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Sources/DeepLearning/Layer.swift b/Sources/DeepLearning/Layer.swift index b6559049a..2dbe262a8 100644 --- a/Sources/DeepLearning/Layer.swift +++ b/Sources/DeepLearning/Layer.swift @@ -1351,7 +1351,7 @@ public extension RNNCell { /// - Returns: The output. @differentiable func applied(to timeStepInput: StepInput, previous: State, in context: Context) -> State { - return applied(to: Input(timeStepInput: timeStepInput, previousState: previousState), + return applied(to: Input(timeStepInput: timeStepInput, previousState: previous), in: context) } } From 4d32dbf54cbd7143517d64abc2f73ab928aa1c1e Mon Sep 17 00:00:00 2001 From: Richard Wei Date: Sun, 7 Apr 2019 06:30:01 +0100 Subject: [PATCH 3/4] Fix build. --- Sources/DeepLearning/Layer.swift | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Sources/DeepLearning/Layer.swift b/Sources/DeepLearning/Layer.swift index 2dbe262a8..335d1ad57 100644 --- a/Sources/DeepLearning/Layer.swift +++ b/Sources/DeepLearning/Layer.swift @@ -1350,7 +1350,7 @@ public extension RNNCell { /// phase. /// - Returns: The output. @differentiable - func applied(to timeStepInput: StepInput, previous: State, in context: Context) -> State { + func applied(to timeStepInput: TimeStepInput, previous: State, in context: Context) -> State { return applied(to: Input(timeStepInput: timeStepInput, previousState: previous), in: context) } From 87f8260bde2528ddb6fcd22178f625c32401b838 Mon Sep 17 00:00:00 2001 From: Richard Wei Date: Sun, 7 Apr 2019 06:47:34 +0100 Subject: [PATCH 4/4] Update Layer.swift --- Sources/DeepLearning/Layer.swift | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Sources/DeepLearning/Layer.swift b/Sources/DeepLearning/Layer.swift index 335d1ad57..28116fe5f 100644 --- a/Sources/DeepLearning/Layer.swift +++ b/Sources/DeepLearning/Layer.swift @@ -1323,7 +1323,7 @@ public struct RNNInput: Di public var previousState: State @differentiable - public init(timeStepInput: StepInput, previousState: State) { + public init(timeStepInput: TimeStepInput, previousState: State) { self.timeStepInput = timeStepInput self.previousState = previousState }