Fixed a couple RNN bugs. #522

Merged
merged 9 commits on Nov 14, 2019

73 changes: 37 additions & 36 deletions
Sources/TensorFlow/Layers/Recurrent.swift
@@ -56,8 +56,9 @@ public protocol RNNCell: Layer
     associatedtype TimeStepOutput: Differentiable
     /// The state that may be preserved across time steps.
     associatedtype State: Differentiable
-    /// The zero state.
-    var zeroState: State { get }
+
+    /// Returns a zero-valued state with shape compatible with the provided input.
+    func zeroState(for input: TimeStepInput) -> State
 }
 
 public extension RNNCell {
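
The protocol change above replaces the stored `zeroState` property, which fixed the state's leading dimension at 1, with a `zeroState(for:)` method that derives the batch dimension from the input about to be processed. A minimal usage sketch (sizes are hypothetical; assumes the swift-apis API at this commit):

    // Build a cell, then size the initial state to the batch of the input.
    import TensorFlow

    let cell = SimpleRNNCell<Float>(inputSize: 4, hiddenSize: 8)
    let x = Tensor<Float>(zeros: [16, 4])      // one time step, batch of 16
    let state = cell.zeroState(for: x)         // leading dimension follows `x`
    let result = cell(RNNCellInput(input: x, state: state))
    // `result.output.value` and `result.state.value` both have shape [16, 8].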
@@ -87,14 +88,6 @@ public struct SimpleRNNCell<Scalar: TensorFlowFloatingPoint>: RNNCell {
     public var weight: Tensor<Scalar>
     public var bias: Tensor<Scalar>
 
-    @noDerivative public var stateShape: TensorShape {
-        TensorShape([1, weight.shape[1]])
-    }
-
-    public var zeroState: State {
-        State(Tensor(zeros: stateShape))
-    }
-
     // TODO(TF-507): Revert to `typealias State = Tensor<Scalar>` after SR-10697 is fixed.
     public struct State: Equatable, Differentiable, VectorProtocol, KeyPathIterable {
         public var value: Tensor<Scalar>
@@ -120,6 +113,11 @@ public struct SimpleRNNCell<Scalar: TensorFlowFloatingPoint>: RNNCell {
         self.bias = Tensor(zeros: [hiddenSize])
     }
 
+    /// Returns a zero-valued state with shape compatible with the provided input.
+    public func zeroState(for input: Tensor<Scalar>) -> State {
+        State(Tensor(zeros: [input.shape[0], weight.shape[1]]))
+    }
+
     /// Returns the output obtained from applying the layer to the given input.
     ///
     /// - Parameter input: The input to the layer.
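
For `SimpleRNNCell`, the zero state's leading dimension now tracks the input batch instead of the hard-coded 1 that the removed `stateShape` produced. A quick shape check as a sketch (assumes the API at this commit):

    import TensorFlow

    let cell = SimpleRNNCell<Float>(inputSize: 3, hiddenSize: 5)
    let batch = Tensor<Float>(zeros: [7, 3])   // batch of 7
    let state = cell.zeroState(for: batch)
    assert(state.value.shape == [7, 5])        // previously [1, 5] for any batch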
@@ -185,14 +183,6 @@ public struct LSTMCell<Scalar: TensorFlowFloatingPoint>: RNNCell {
         return fusedBias.slice(lowerBounds: [3 * hiddenSize], upperBounds: [4 * hiddenSize])
     }
 
-    @noDerivative public var stateShape: TensorShape {
-        TensorShape([1, fusedWeight.shape[1] / 4])
-    }
-
-    public var zeroState: State {
-        State(cell: Tensor(zeros: stateShape), hidden: Tensor(zeros: stateShape))
-    }
-
     public typealias TimeStepInput = Tensor<Scalar>
     public typealias TimeStepOutput = State
     public typealias Input = RNNCellInput<TimeStepInput, State>
@@ -219,6 +209,14 @@
         }
     }
 
+    /// Returns a zero-valued state with shape compatible with the provided input.
+    public func zeroState(for input: Tensor<Scalar>) -> State {
+        let hiddenSize = fusedWeight.shape[1] / 4
+        return State(
+            cell: Tensor(zeros: [input.shape[0], hiddenSize]),
+            hidden: Tensor(zeros: [input.shape[0], hiddenSize]))
+    }
+
     /// Returns the output obtained from applying the layer to the given input.
     ///
     /// - Parameter input: The input to the layer.
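
`LSTMCell` recovers `hiddenSize` from `fusedWeight`, whose second dimension concatenates the four gate blocks, hence the division by 4. Sketch of the resulting state shapes (assumes the API at this commit):

    import TensorFlow

    let lstm = LSTMCell<Float>(inputSize: 3, hiddenSize: 5)
    let x = Tensor<Float>(zeros: [2, 3])       // batch of 2
    let s = lstm.zeroState(for: x)
    assert(s.cell.shape == [2, 5])             // cell state, one row per example
    assert(s.hidden.shape == [2, 5])           // hidden state, one row per example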
@@ -268,27 +266,28 @@ public struct RNN<Cell: RNNCell>: Layer {
         self.cell = cell()
     }
 
-    @differentiable(wrt: (self, input), vjp: _vjpCallAsFunction(_:initialState:))
+    @differentiable(wrt: (self, inputs), vjp: _vjpCallAsFunction(_:initialState:))
     public func callAsFunction(
-        _ input: [Cell.TimeStepInput],
+        _ inputs: [Cell.TimeStepInput],
         initialState: Cell.State
     ) -> [Cell.TimeStepOutput] {
+        if inputs.isEmpty { return [Cell.TimeStepOutput]() }
         var currentHiddenState = initialState
         var timeStepOutputs: [Cell.TimeStepOutput] = []
-        for timestep in input {
-            let output = cell(input: timestep, state: currentHiddenState)
+        for timeStepInput in inputs {
+            let output = cell(input: timeStepInput, state: currentHiddenState)
             currentHiddenState = output.state
             timeStepOutputs.append(output.output)
         }
         return timeStepOutputs
     }
 
-    @differentiable(wrt: (self, input))
+    @differentiable(wrt: (self, inputs))
     public func call(
-        _ input: [Cell.TimeStepInput],
+        _ inputs: [Cell.TimeStepInput],
         initialState: Cell.State
     ) -> [Cell.TimeStepOutput] {
-        callAsFunction(input, initialState: initialState)
+        callAsFunction(inputs, initialState: initialState)
     }
 
     @usableFromInline
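
Besides renaming `input` to `inputs` for the sequence-valued parameter, this hunk adds an early return so an empty sequence yields an empty output instead of tripping over a missing first element. A minimal end-to-end sketch (sizes are hypothetical; assumes the API at this commit):

    import TensorFlow

    let rnn = RNN(SimpleRNNCell<Float>(inputSize: 4, hiddenSize: 8))
    let sequence = (0..<10).map { _ in Tensor<Float>(randomNormal: [16, 4]) }
    let outputs = rnn(sequence)                // zero state derived from sequence[0]
    assert(outputs.count == 10)

    // The new guard makes an explicit-state call with no time steps a no-op.
    let state = rnn.cell.zeroState(for: sequence[0])
    assert(rnn([], initialState: state).isEmpty)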
@@ -299,7 +298,7 @@ public struct RNN<Cell: RNNCell>: Layer {
           (Array<Cell.TimeStepOutput>.TangentVector)
               -> (TangentVector, Array<Cell.TimeStepInput>.TangentVector)) {
         let timeStepCount = inputs.count
-        var currentHiddenState = cell.zeroState
+        var currentHiddenState = cell.zeroState(for: inputs[0])
         var timeStepOutputs: [Cell.TimeStepOutput] = []
         timeStepOutputs.reserveCapacity(timeStepCount)
         var backpropagators: [Cell.Backpropagator] = []
@@ -320,7 +319,7 @@ public struct RNN<Cell: RNNCell>: Layer {
             reversed𝛁inputs.reserveCapacity(timeStepCount)
             for (𝛁output, backpropagator) in zip(𝛁outputs.base, backpropagators).reversed() {
                 let (new𝛁cell, 𝛁input) = backpropagator(.init(output: 𝛁output, state: 𝛁state))
-                𝛁cell = new𝛁cell
+                𝛁cell += new𝛁cell
                 𝛁state = 𝛁input.state
                 reversed𝛁inputs.append(𝛁input.input)
             }
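
The `+=` is the headline fix: the pullback previously overwrote `𝛁cell` on each backward step, so after the reversed loop only the first time step's contribution survived, even though the cell's parameters are shared across every step and their gradients must be summed. The pattern in miniature (plain Swift, illustrative values):

    // Shared-parameter gradients: overwriting vs. accumulating per-step results.
    let perStepGradients: [Float] = [0.5, 0.25, 0.25]   // latest time step first

    var overwritten: Float = 0
    var accumulated: Float = 0
    for g in perStepGradients {
        overwritten = g         // old behavior: keeps only the last iteration
        accumulated += g        // fixed behavior: sums every step's contribution
    }
    assert(overwritten == 0.25)
    assert(accumulated == 1.0)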
@@ -330,23 +329,25 @@ public struct RNN<Cell: RNNCell>: Layer {
 
     @differentiable
     public func callAsFunction(_ inputs: [Cell.TimeStepInput]) -> [Cell.TimeStepOutput] {
-        return self(inputs, initialState: withoutDerivative(at: cell.zeroState))
+        let initialState = withoutDerivative(at: cell.zeroState(for: inputs[0]))
+        return self(inputs, initialState: withoutDerivative(at: initialState))
     }
 
     /* TODO: Uncomment once control flow and differentiation through force unwrapping is supported.
     @differentiable(wrt: (self, inputs))
-    public func lastOutput(from inputs: [Cell.TimeStepInput],
-                           initialState: Cell.State) -> Cell.TimeStepOutput {
-        precondition(!inputs.isEmpty, "inputs cannot be empty")
-        return self(inputs, initialState: initialState).last!
+    public func lastOutput(
+        from inputs: [Cell.TimeStepInput],
+        initialState: Cell.State
+    ) -> Cell.TimeStepOutput {
+        precondition(!inputs.isEmpty, "'inputs' must be non-empty.")
+        return self(inputs, initialState: initialState)[withoutDerivative(at: inputs.count - 1)]
     }
 
     @differentiable(wrt: (self, inputs))
     public func lastOutput(from inputs: [Cell.TimeStepInput]) -> Cell.TimeStepOutput {
-        precondition(!inputs.isEmpty, "inputs cannot be empty")
-        return self(inputs, initialState: cell.zeroState).last!
+        precondition(!inputs.isEmpty, "'inputs' must be non-empty.")
+        let initialState = withoutDerivative(at: cell.zeroState(for: inputs[0]))
+        return lastOutput(from: inputs, initialState: initialState)
     }
     */
 }
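
The still-commented `lastOutput` helpers are updated in the same spirit: the `inputs` naming, an input-derived zero state, and an index subscript in place of `.last!`, since differentiation through force unwrapping is not yet supported. Until they are enabled, a caller can do the equivalent by hand; a sketch (assumes the API at this commit):

    import TensorFlow

    let rnn = RNN(SimpleRNNCell<Float>(inputSize: 2, hiddenSize: 3))
    let sequence = (0..<5).map { _ in Tensor<Float>(randomNormal: [1, 2]) }
    let outputs = rnn(sequence)
    let last = outputs[sequence.count - 1]     // same index access the helper will use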

extension RNN: Equatable where Cell: Equatable {}