112 changes: 103 additions & 9 deletions Sources/DeepLearning/Layer.swift
@@ -1281,7 +1281,7 @@ public extension RNNCell {
}

/// A Simple RNN Cell.
public struct SimpleRNNCell<Scalar: TensorFlowFloatingPoint>: RNNCell {
public struct SimpleRNNCell<Scalar: TensorFlowFloatingPoint>: RNNCell, VectorNumeric {
public var weight: Tensor<Scalar>
public var bias: Tensor<Scalar>

@@ -1304,9 +1304,13 @@ public struct SimpleRNNCell<Scalar: TensorFlowFloatingPoint>: RNNCell {
/// - Parameters:
/// - inputSize: The number of features in 2-D input tensors.
/// - hiddenSize: The number of features in 2-D hidden states.
public init(inputSize: Int, hiddenSize: Int) {
/// - seed: The random seed for initialization. The default value is random.
public init(inputSize: Int, hiddenSize: Int,
seed: (Int64, Int64) = (Int64.random(in: Int64.min..<Int64.max),
Int64.random(in: Int64.min..<Int64.max))) {
let concatenatedInputSize = inputSize + hiddenSize
self.weight = Tensor(glorotUniform: [concatenatedInputSize, hiddenSize])
self.weight = Tensor(glorotUniform: [concatenatedInputSize, hiddenSize],
seed: seed)
self.bias = Tensor(zeros: [hiddenSize])
}

@@ -1326,7 +1330,7 @@ public struct SimpleRNNCell<Scalar: TensorFlowFloatingPoint>: RNNCell {
}
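
The only behavioral change in this cell is determinism: passing a fixed `seed` makes the Glorot-uniform initialization reproducible across runs. A minimal usage sketch, outside this diff (the seed values are arbitrary; this assumes the seeded `Tensor(glorotUniform:seed:)` initializer is deterministic for a fixed seed, and that `Tensor` equality is usable the way the tests below use it):

    let a = SimpleRNNCell<Float>(inputSize: 4, hiddenSize: 4, seed: (42, 117))
    let b = SimpleRNNCell<Float>(inputSize: 4, hiddenSize: 4, seed: (42, 117))
    // Same seed, same shape: both cells start from identical weights.
    assert(a.weight == b.weight)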

/// An LSTM Cell.
public struct LSTMCell<Scalar: TensorFlowFloatingPoint>: RNNCell {
public struct LSTMCell<Scalar: TensorFlowFloatingPoint>: RNNCell, VectorNumeric {
public var inputWeight, updateWeight, forgetWeight, outputWeight: Tensor<Scalar>
public var inputBias, updateBias, forgetBias, outputBias: Tensor<Scalar>

@@ -1348,17 +1352,19 @@ public struct LSTMCell<Scalar: TensorFlowFloatingPoint>: RNNCell {
/// - Parameters:
/// - inputSize: The number of features in 2-D input tensors.
/// - hiddenSize: The number of features in 2-D hidden states.
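/// - seed: The random seed for initialization. The default value is random.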
public init(inputSize: Int, hiddenSize: Int) {
public init(inputSize: Int, hiddenSize: Int,
seed: (Int64, Int64) = (Int64.random(in: Int64.min..<Int64.max),
Int64.random(in: Int64.min..<Int64.max))) {
let concatenatedInputSize = inputSize + hiddenSize
let gateWeightShape = TensorShape([concatenatedInputSize, hiddenSize])
let gateBiasShape = TensorShape([hiddenSize])
self.inputWeight = Tensor(glorotUniform: gateWeightShape)
self.inputWeight = Tensor(glorotUniform: gateWeightShape, seed: seed)
self.inputBias = Tensor(zeros: gateBiasShape)
self.updateWeight = Tensor(glorotUniform: gateWeightShape)
self.updateWeight = Tensor(glorotUniform: gateWeightShape, seed: seed)
self.updateBias = Tensor(zeros: gateBiasShape)
self.forgetWeight = Tensor(glorotUniform: gateWeightShape)
self.forgetWeight = Tensor(glorotUniform: gateWeightShape, seed: seed)
self.forgetBias = Tensor(ones: gateBiasShape)
self.outputWeight = Tensor(glorotUniform: gateWeightShape)
self.outputWeight = Tensor(glorotUniform: gateWeightShape, seed: seed)
self.outputBias = Tensor(zeros: gateBiasShape)
}

@@ -1397,3 +1403,91 @@ public struct LSTMCell<Scalar: TensorFlowFloatingPoint>: RNNCell {
return Output(output: newState, state: newState)
}
}
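
One observation on the initializer above (an aside; the diff leaves this as-is): all four gate weight matrices are drawn with the same `seed` and the same shape, so, assuming the seeded Glorot initializer is deterministic, they start out elementwise identical. A sketch that would expose this:

    let cell = LSTMCell<Float>(inputSize: 4, hiddenSize: 4, seed: (1, 2))
    // Same (seed, shape) pair for every gate means equal initial weights.
    assert(cell.inputWeight == cell.forgetWeight)

Offsetting the seed per gate (for example, `(seed.0, seed.1 + 1)` for the second gate) would keep reproducibility while decorrelating the gates.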

public struct RNN<Cell: RNNCell>: Layer {
public typealias Input = [Cell.TimeStepInput]
public typealias Output = [Cell.TimeStepOutput]

public var cell: Cell

public init(_ cell: @autoclosure () -> Cell) {
self.cell = cell()
}

@differentiable(wrt: (self, input), vjp: _vjpCall(_:initialState:))
public func call(_ input: [Cell.TimeStepInput],
initialState: Cell.State) -> [Cell.TimeStepOutput] {
var currentHiddenState = initialState
var timeStepOutputs: [Cell.TimeStepOutput] = []
for timestep in input {
let output = cell(input: timestep, state: currentHiddenState)
currentHiddenState = output.state
timeStepOutputs.append(output.output)
}
return timeStepOutputs
}

@usableFromInline
internal func _vjpCall(
_ inputs: [Cell.TimeStepInput], initialState: Cell.State
) -> ([Cell.TimeStepOutput],
(Array<Cell.TimeStepOutput>.CotangentVector)
-> (CotangentVector, Array<Cell.TimeStepInput>.CotangentVector)) {
let timeStepCount = inputs.count
var currentHiddenState = cell.zeroState
var timeStepOutputs: [Cell.TimeStepOutput] = []
timeStepOutputs.reserveCapacity(timeStepCount)
var backpropagators: [Cell.Backpropagator] = []
backpropagators.reserveCapacity(timeStepCount)
for timestep in inputs {
let (output, backpropagator) =
cell.appliedForBackpropagation(to: .init(input: timestep,
state: currentHiddenState))
currentHiddenState = output.state
timeStepOutputs.append(output.output)
backpropagators.append(backpropagator)
}
return (timeStepOutputs, { 𝛁outputs in
precondition(𝛁outputs.base.count == timeStepCount,
"The number of output gradients must equal the number of time steps")
var 𝛁cell = Cell.CotangentVector.zero
var 𝛁state = Cell.State.CotangentVector.zero
var reversed𝛁inputs: [Cell.TimeStepInput.CotangentVector] = []
reversed𝛁inputs.reserveCapacity(timeStepCount)
for (𝛁output, backpropagator) in zip(𝛁outputs.base, backpropagators).reversed() {
let (new𝛁cell, 𝛁input) = backpropagator(.init(output: 𝛁output, state: 𝛁state))
𝛁cell = new𝛁cell
𝛁state = 𝛁input.state
reversed𝛁inputs.append(𝛁input.input)
}
return (.init(cell: 𝛁cell), .init(Array(reversed𝛁inputs.reversed())))
Review comment (Contributor Author):

Note: Regarding this array reversal, while I could've zero-initialized an array of timeStepCount tensors and modified them in reverse order, it would be less efficient because of the cost of heap-allocating timeStepCount extra tensors.
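
Concretely, the rejected alternative would look roughly like the following sketch, which reuses the names in scope inside the pullback closure above (not code from this diff):

    // The alternative the comment describes: zero-initialize a buffer of
    // timeStepCount cotangent tensors up front, then write entries in reverse.
    var 𝛁inputs = [Cell.TimeStepInput.CotangentVector](repeating: .zero,
                                                       count: timeStepCount)
    for t in (0..<timeStepCount).reversed() {
        let (new𝛁cell, 𝛁input) = backpropagators[t](
            .init(output: 𝛁outputs.base[t], state: 𝛁state))
        𝛁cell = new𝛁cell
        𝛁state = 𝛁input.state
        𝛁inputs[t] = 𝛁input.input
    }

Appending into a compact buffer and reversing once at the end trades those up-front tensor allocations for a single O(n) copy.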

})
}
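
An aside on the reverse sweep above (an observation, not a change in this diff): `𝛁cell` is overwritten on each iteration, so after the loop it holds only the earliest time step's parameter gradient, and the forward sweep starts from `cell.zeroState` rather than the `initialState` argument. Since the cell's parameters are shared across all steps, a full backpropagation-through-time accumulation would sum the per-step contributions, hypothetically:

    𝛁cell += new𝛁cell  // `+=` via `AdditiveArithmetic` on Cell.CotangentVector

The all-zero hidden-state rows in the expected weight gradient of `testRNN` below are consistent with the overwrite behavior.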

@differentiable(wrt: (self, inputs))
public func call(_ inputs: [Cell.TimeStepInput]) -> [Cell.TimeStepOutput] {
return self(inputs, initialState: cell.zeroState.withoutDerivative())
}

/* TODO: Uncomment once control flow and differentiation through force unwrapping are supported.
@differentiable(wrt: (self, inputs))
public func lastOutput(from inputs: [Cell.TimeStepInput],
initialState: Cell.State) -> Cell.TimeStepOutput {
precondition(!inputs.isEmpty, "inputs cannot be empty")
return self(inputs, initialState: initialState).last!
}

@differentiable(wrt: (self, inputs))
public func lastOutput(from inputs: [Cell.TimeStepInput]) -> Cell.TimeStepOutput {
precondition(!inputs.isEmpty, "inputs cannot be empty")
return self(inputs, initialState: cell.zeroState).last!
}
*/
}

extension RNN: Equatable where Cell: Equatable {}
extension RNN: AdditiveArithmetic where Cell: AdditiveArithmetic {}
extension RNN: VectorNumeric where Cell: VectorNumeric {}

public typealias SimpleRNN<Scalar: TensorFlowFloatingPoint> = RNN<SimpleRNNCell<Scalar>>
public typealias LSTM<Scalar: TensorFlowFloatingPoint> = RNN<LSTMCell<Scalar>>
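
With the type aliases in place, an unrolled recurrence over a sequence reads as below. This is a hypothetical usage sketch, not code from the diff (the shapes and seed are arbitrary, and it assumes a `Tensor(randomNormal:)` initializer is available):

    // A simple RNN applied to a 4-step sequence of [batch, features] tensors.
    let rnn = SimpleRNN(SimpleRNNCell<Float>(inputSize: 4, hiddenSize: 4,
                                             seed: (0xFeed, 0xBeef)))
    let sequence = (0..<4).map { _ in Tensor<Float>(randomNormal: [1, 4]) }
    let outputs = rnn(sequence)  // one hidden-state output per time step
    assert(outputs.count == sequence.count)
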
28 changes: 27 additions & 1 deletion Tests/DeepLearningTests/LayerTests.swift
@@ -95,6 +95,31 @@ final class LayerTests: XCTestCase {
XCTAssertEqual(output, expected)
}

func testRNN() {
let x = Tensor<Float>(rangeFrom: 0.0, to: 0.4, stride: 0.1).rankLifted()
let inputs: [Tensor<Float>] = Array(repeating: x, count: 4)
let rnn = RNN(SimpleRNNCell<Float>(inputSize: 4, hiddenSize: 4,
seed: (0xFeedBeef, 0xDeadBeef)))
let (outputs, pullback) = rnn.valueWithPullback(at: inputs) { rnn, inputs in
return rnn(inputs)
}
XCTAssertEqual(outputs, [[[-0.0026294366, -0.0058668107, 0.04495003, 0.20311214]],
[[ 0.06788494, 0.050665878, 0.02415526, 0.09249911]],
[[ 0.06621192, 0.009049267, 0.065047316, 0.11534518]],
[[ 0.05612204, 0.00022032857, 0.05407162, 0.09784105]]])
let (𝛁rnn, 𝛁inputs) = pullback(.init(inputs))
XCTAssertEqual(𝛁rnn.cell.weight,
[[ 0.0, 0.0, 0.0, 0.0],
[-0.0051278225, 0.0013102926, 0.00740262, 0.018119661],
[ -0.010255645, 0.0026205853, 0.01480524, 0.036239322],
[ -0.015383467, 0.003930878, 0.02220786, 0.054358985],
[ 0.0, 0.0, 0.0, 0.0],
[ 0.0, 0.0, 0.0, 0.0],
[ 0.0, 0.0, 0.0, 0.0],
[ 0.0, 0.0, 0.0, 0.0]])
XCTAssertEqual(𝛁rnn.cell.bias, [-0.051278222, 0.013102926, 0.0740262, 0.18119662])
}

static var allTests = [
("testConv1D", testConv1D),
("testMaxPool1D", testMaxPool1D),
@@ -104,6 +129,7 @@ final class LayerTests: XCTestCase {
("testGlobalAvgPool3D", testGlobalAvgPool3D),
("testReshape", testReshape),
("testFlatten", testFlatten),
("testSimpleRNNCell", testSimpleRNNCell)
("testSimpleRNNCell", testSimpleRNNCell),
("testRNN", testRNN)
]
}