From 2095c8356dd192feba89c64e07cf998dad4d7ddd Mon Sep 17 00:00:00 2001 From: PAWAN SASANKA AMMANAMANCHI Date: Fri, 19 Apr 2019 15:05:30 +0530 Subject: [PATCH 1/8] Breaking down Layers into a directory --- .../DeepLearning/Layers/Convolutional.swift | 380 ++++++++++++++++++ Sources/DeepLearning/Layers/Core.swift | 239 +++++++++++ Sources/DeepLearning/Layers/Layer.swift | 192 +++++++++ .../DeepLearning/Layers/Normalization.swift | 202 ++++++++++ Sources/DeepLearning/Layers/Pooling.Swift | 240 +++++++++++ Sources/DeepLearning/Layers/Recurrent.Swift | 192 +++++++++ Sources/DeepLearning/Layers/Upsampling.swift | 72 ++++ 7 files changed, 1517 insertions(+) create mode 100644 Sources/DeepLearning/Layers/Convolutional.swift create mode 100644 Sources/DeepLearning/Layers/Core.swift create mode 100644 Sources/DeepLearning/Layers/Layer.swift create mode 100644 Sources/DeepLearning/Layers/Normalization.swift create mode 100644 Sources/DeepLearning/Layers/Pooling.Swift create mode 100644 Sources/DeepLearning/Layers/Recurrent.Swift create mode 100644 Sources/DeepLearning/Layers/Upsampling.swift diff --git a/Sources/DeepLearning/Layers/Convolutional.swift b/Sources/DeepLearning/Layers/Convolutional.swift new file mode 100644 index 000000000..ad0f2bb59 --- /dev/null +++ b/Sources/DeepLearning/Layers/Convolutional.swift @@ -0,0 +1,380 @@ +// Copyright 2018 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#if !COMPILING_TENSORFLOW_MODULE +@_exported import TensorFlow +#endif + + +/// A 1-D convolution layer (e.g. temporal convolution over a time-series). +/// +/// This layer creates a convolution filter that is convolved with the layer input to produce a +/// tensor of outputs. +@_fixed_layout +public struct Conv1D: Layer { + /// The 3-D convolution kernel `[width, inputChannels, outputChannels]`. + public var filter: Tensor + /// The bias vector `[outputChannels]`. + public var bias: Tensor + /// An activation function. + public typealias Activation = @differentiable (Tensor) -> Tensor + /// The element-wise activation function. + @noDerivative public let activation: Activation + /// The stride of the sliding window for temporal dimension. + @noDerivative public let stride: Int + /// The padding algorithm for convolution. + @noDerivative public let padding: Padding + + /// Creates a `Conv1D` layer with the specified filter, bias, activation function, stride, and + /// padding. + /// + /// - Parameters: + /// - filter: The 3-D convolution kernel `[width, inputChannels, outputChannels]`. + /// - bias: The bias vector `[outputChannels]`. + /// - activation: The element-wise activation function. + /// - stride: The stride of the sliding window for temporal dimension. + /// - padding: The padding algorithm for convolution. 
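+    ///
+    /// A minimal usage sketch; the shapes and the `relu` activation below are illustrative
+    /// assumptions, not requirements of this initializer:
+    ///
+    ///     let filter = Tensor<Float>(glorotUniform: [3, 4, 8])
+    ///     let bias = Tensor<Float>(zeros: [8])
+    ///     let conv = Conv1D<Float>(filter: filter, bias: bias, activation: relu,
+    ///                              stride: 1, padding: .same)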
+ public init( + filter: Tensor, + bias: Tensor, + activation: @escaping Activation, + stride: Int, + padding: Padding + ) { + self.filter = filter + self.bias = bias + self.activation = activation + self.stride = stride + self.padding = padding + } + + /// Returns the output obtained from applying the layer to the given input. + /// + /// - Parameters: + /// - input: The input to the layer `[batchCount, width, inputChannels]`. + /// - Returns: The output `[batchCount, newWidth, outputChannels]`. + @differentiable + public func call(_ input: Tensor) -> Tensor { + let conv2D = input.expandingShape(at: 1).convolved2D( + withFilter: filter.expandingShape(at: 0), strides: (1, 1, stride, 1), padding: padding) + return activation(conv2D.squeezingShape(at: 1) + bias) + } +} + +public extension Conv1D where Scalar.RawSignificand: FixedWidthInteger { + /// Creates a `Conv1D` layer with the specified filter shape, stride, padding, and + /// element-wise activation function. The filter tensor is initialized using Glorot uniform + /// initialization with the specified generator. The bias vector is initialized with zeros. + /// + /// - Parameters: + /// - filterShape: The 3-D shape of the filter, representing + /// `[width, inputChannels, outputChannels]`. + /// - stride: The stride of the sliding window for temporal dimension. + /// - padding: The padding algorithm for convolution. + /// - activation: The element-wise activation function. + /// - generator: The random number generator for initialization. + /// + /// - Note: Use `init(filterShape:stride:padding:activation:seed:)` for faster random + /// initialization. + init( + filterShape: (Int, Int, Int), + stride: Int = 1, + padding: Padding = .valid, + activation: @escaping Activation = identity, + generator: inout G + ) { + let filterTensorShape = TensorShape([ + filterShape.0, filterShape.1, filterShape.2]) + self.init( + filter: Tensor(glorotUniform: filterTensorShape), + bias: Tensor(zeros: TensorShape([filterShape.2])), + activation: activation, + stride: stride, + padding: padding) + } +} + +public extension Conv1D { + /// Creates a `Conv1D` layer with the specified filter shape, strides, padding, and + /// element-wise activation function. The filter tensor is initialized using Glorot uniform + /// initialization with the specified seed. The bias vector is initialized with zeros. + /// + /// - Parameters: + /// - filterShape: The 3-D shape of the filter, representing + /// `[width, inputChannels, outputChannels]`. + /// - stride: The stride of the sliding window for temporal dimension. + /// - padding: The padding algorithm for convolution. + /// - activation: The element-wise activation function. + /// - seed: The random seed for initialization. The default value is random. + init( + filterShape: (Int, Int, Int), + stride: Int = 1, + padding: Padding = .valid, + activation: @escaping Activation = identity, + seed: (Int64, Int64) = (Int64.random(in: Int64.min..: Layer { + /// The 4-D convolution kernel. + public var filter: Tensor + /// The bias vector. + public var bias: Tensor + /// An activation function. + public typealias Activation = @differentiable (Tensor) -> Tensor + /// The element-wise activation function. + @noDerivative public let activation: Activation + /// The strides of the sliding window for spatial dimensions. + @noDerivative public let strides: (Int, Int) + /// The padding algorithm for convolution. 
+ @noDerivative public let padding: Padding + + /// Creates a `Conv2D` layer with the specified filter, bias, activation function, strides, and + /// padding. + /// + /// - Parameters: + /// - filter: The 4-D convolution kernel. + /// - bias: The bias vector. + /// - activation: The element-wise activation function. + /// - strides: The strides of the sliding window for spatial dimensions. + /// - padding: The padding algorithm for convolution. + public init( + filter: Tensor, + bias: Tensor, + activation: @escaping Activation, + strides: (Int, Int), + padding: Padding + ) { + self.filter = filter + self.bias = bias + self.activation = activation + self.strides = strides + self.padding = padding + } + + /// Returns the output obtained from applying the layer to the given input. + /// + /// - Parameters: + /// - input: The input to the layer. + /// - Returns: The output. + @differentiable + public func call(_ input: Tensor) -> Tensor { + return activation(input.convolved2D(withFilter: filter, + strides: (1, strides.0, strides.1, 1), + padding: padding) + bias) + } +} + +public extension Conv2D { + /// Creates a `Conv2D` layer with the specified filter shape, strides, padding, and + /// element-wise activation function. The filter tensor is initialized using Glorot uniform + /// initialization with the specified generator. The bias vector is initialized with zeros. + /// + /// - Parameters: + /// - filterShape: The shape of the 4-D convolution kernel. + /// - strides: The strides of the sliding window for spatial dimensions. + /// - padding: The padding algorithm for convolution. + /// - activation: The element-wise activation function. + /// - generator: The random number generator for initialization. + /// + /// - Note: Use `init(filterShape:strides:padding:activation:seed:)` for faster random + /// initialization. + init( + filterShape: (Int, Int, Int, Int), + strides: (Int, Int) = (1, 1), + padding: Padding = .valid, + activation: @escaping Activation = identity, + generator: inout G + ) { + let filterTensorShape = TensorShape([ + filterShape.0, filterShape.1, filterShape.2, filterShape.3]) + self.init( + filter: Tensor(glorotUniform: filterTensorShape, generator: &generator), + bias: Tensor(zeros: TensorShape([filterShape.3])), + activation: activation, + strides: strides, + padding: padding) + } +} + +public extension Conv2D { + /// Creates a `Conv2D` layer with the specified filter shape, strides, padding, and + /// element-wise activation function. The filter tensor is initialized using Glorot uniform + /// initialization with the specified seed. The bias vector is initialized with zeros. + /// + /// - Parameters: + /// - filterShape: The shape of the 4-D convolution kernel. + /// - strides: The strides of the sliding window for spatial dimensions. + /// - padding: The padding algorithm for convolution. + /// - activation: The element-wise activation function. + /// - seed: The random seed for initialization. The default value is random. + init( + filterShape: (Int, Int, Int, Int), + strides: (Int, Int) = (1, 1), + padding: Padding = .valid, + activation: @escaping Activation = identity, + seed: (Int64, Int64) = (Int64.random(in: Int64.min.. + /// The bias vector. + public var bias: Tensor + /// An activation function. + public typealias Activation = @differentiable (Tensor) -> Tensor + /// The element-wise activation function. + @noDerivative public let activation: Activation + /// The strides of the sliding window for spatial dimensions. 
+ @noDerivative public let strides: (Int, Int) + /// The padding algorithm for convolution. + @noDerivative public let padding: Padding + @noDerivative public let paddingIndex: Int + + /// Creates a `TransposedConv2D` layer with the specified filter, bias, + /// activation function, strides, and padding. + /// + /// - Parameters: + /// - filter: The 4-D convolution kernel. + /// - bias: The bias vector. + /// - activation: The element-wise activation function. + /// - strides: The strides of the sliding window for spatial dimensions. + /// - padding: The padding algorithm for convolution. + public init( + filter: Tensor, + bias: Tensor, + activation: @escaping Activation, + strides: (Int, Int), + padding: Padding + ) { + self.filter = filter + self.bias = bias + self.activation = activation + self.strides = strides + self.padding = padding + self.paddingIndex = padding == .same ? 0 : 1 + } + + /// Returns the output obtained from applying the layer to the given input. + /// + /// - Parameters: + /// - input: The input to the layer. + /// - Returns: The output. + @differentiable + public func call(_ input: Tensor) -> Tensor { + let batchSize = input.shape[0] + let w = (input.shape[1] - (1 * paddingIndex)) * + strides.0 + (filter.shape[0] * paddingIndex) + let h = (input.shape[2] - (1 * paddingIndex)) * + strides.1 + (filter.shape[1] * paddingIndex) + let c = filter.shape[2] + let newShape = Tensor([Int32(batchSize), Int32(w), Int32(h), Int32(c)]) + return activation(input.conv2DBackpropInput(shape: newShape, filter: filter, + strides: (1, strides.0, strides.1, 1), + padding: padding) + bias) + } +} + +public extension TransposedConv2D { + /// Creates a `TransposedConv2D` layer with the specified filter shape, strides, padding, and + /// element-wise activation function. The filter tensor is initialized using Glorot uniform + /// initialization with the specified generator. The bias vector is initialized with zeros. + /// + /// - Parameters: + /// - filterShape: The shape of the 4-D convolution kernel. + /// - strides: The strides of the sliding window for spatial dimensions. + /// - padding: The padding algorithm for convolution. + /// - activation: The element-wise activation function. + /// - generator: The random number generator for initialization. + /// + /// - Note: Use `init(filterShape:strides:padding:activation:seed:)` for faster random + /// initialization. + init( + filterShape: (Int, Int, Int, Int), + strides: (Int, Int) = (1, 1), + padding: Padding = .valid, + activation: @escaping Activation = identity, + generator: inout G + ) { + let filterTensorShape = TensorShape([ + filterShape.0, filterShape.1, filterShape.2, filterShape.3]) + self.init( + filter: Tensor(glorotUniform: filterTensorShape, generator: &generator), + bias: Tensor(zeros: TensorShape([filterShape.3])), + activation: activation, + strides: strides, + padding: padding) + } +} + +public extension TransposedConv2D { + /// Creates a `TransposedConv2D` layer with the specified filter shape, strides, padding, and + /// element-wise activation function. The filter tensor is initialized using Glorot uniform + /// initialization with the specified seed. The bias vector is initialized with zeros. + /// + /// - Parameters: + /// - filterShape: The shape of the 4-D convolution kernel. + /// - strides: The strides of the sliding window for spatial dimensions. + /// - padding: The padding algorithm for convolution. + /// - activation: The element-wise activation function. + /// - seed: The random seed for initialization. 
The default value is random. + init( + filterShape: (Int, Int, Int, Int), + strides: (Int, Int) = (1, 1), + padding: Padding = .valid, + activation: @escaping Activation = identity, + seed: (Int64, Int64) = (Int64.random(in: Int64.min.. Tensor { + let noise = Tensor(randomUniform: shape) + let keepMask = noise .>= Scalar(probability) + let keepProbability = Scalar(1.0 - probability) + return self * Tensor(keepMask) / Tensor(keepProbability) + } +} + + +/// A dropout layer. +/// +/// Dropout consists in randomly setting a fraction of input units to `0` at each update during +/// training time, which helps prevent overfitting. +@_fixed_layout +public struct Dropout: Layer { + @noDerivative public let probability: Double + + /// Creates a dropout layer. + /// + /// - Parameter probability: The drop probability. + public init(probability: Double) { + self.probability = probability + } + + @differentiable + private func applyingTraining(to input: Tensor) -> Tensor { + return input.droppingOut(probability: probability) + } + + @differentiable + private func applyingInference(to input: Tensor) -> Tensor { + return input + } + + /// Returns the output obtained from applying the layer to the given input. + /// + /// - Parameters: + /// - input: The input to the layer. + /// - Returns: The output. + @differentiable(vjp: _vjpApplied(to:)) + public func call(_ input: Tensor) -> Tensor { + switch Context.local.learningPhase { + case .training: + return applyingTraining(to: input) + case .inference: + return applyingInference(to: input) + } + } + + @usableFromInline + func _vjpApplied(to input: Tensor) -> + (Tensor, (Tensor) -> + (Dropout.CotangentVector, Tensor)) { + switch Context.local.learningPhase { + case .training: + return valueWithPullback(at: input) { + $0.applyingTraining(to: $1) + } + case .inference: + return valueWithPullback(at: input) { + $0.applyingInference(to: $1) + } + } + } +} + + + +/// A flatten layer. +/// +/// A flatten layer flattens the input when applied without affecting the batch size. +@_fixed_layout +public struct Flatten: Layer { + /// Creates a flatten layer. + public init() {} + + /// Returns the output obtained from applying the layer to the given input. + /// + /// - Parameters: + /// - input: The input to the layer. + /// - Returns: The output. + @differentiable + public func call(_ input: Tensor) -> Tensor { + let batchSize = input.shape[0] + let remaining = input.shape[1..: Layer { + /// The target shape. + @noDerivative public let shape: Tensor + + // TF-331 workaround: + @usableFromInline + internal var _nontrivial = Tensor(0) + + /// Creates a reshape layer. + /// + /// - Parameter shape: The target shape, represented by a tensor. + public init(shape: Tensor) { + self.shape = shape + } + + /// Creates a reshape layer. + /// + /// - Parameter shape: The target shape. + public init(_ shape: TensorShape) { + self.init(shape: Tensor(shape.dimensions.map(Int32.init))) + } + + /// Returns the output obtained from applying the layer to the given input. + /// + /// - Parameters: + /// - input: The input to the layer. + /// - Returns: The output. + @differentiable + public func call(_ input: Tensor) -> Tensor { + return input.reshaped(toShape: shape) + } +} + + + +/// A densely-connected neural network layer. +/// +/// `Dense` implements the operation `activation(matmul(input, weight) + bias)`, where `weight` is +/// a weight matrix, `bias` is a bias vector, and `activation` is an element-wise activation +/// function. 
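+///
+/// A brief usage sketch; the layer sizes and the `relu` activation are illustrative assumptions:
+///
+///     let dense = Dense<Float>(inputSize: 4, outputSize: 2, activation: relu)
+///     let output = dense(Tensor<Float>(ones: [1, 4]))  // Output shape: [1, 2].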
+@_fixed_layout +public struct Dense: Layer { + /// The weight matrix. + public var weight: Tensor + /// The bias vector. + public var bias: Tensor + public typealias Activation = @differentiable (Tensor) -> Tensor + /// The element-wise activation function. + @noDerivative public let activation: Activation + + public init( + weight: Tensor, + bias: Tensor, + activation: @escaping Activation + ) { + self.weight = weight + self.bias = bias + self.activation = activation + } + + /// Returns the output obtained from applying the layer to the given input. + /// + /// - Parameters: + /// - input: The input to the layer. + /// - Returns: The output. + @differentiable + public func call(_ input: Tensor) -> Tensor { + return activation(matmul(input, weight) + bias) + } +} + +public extension Dense { + /// Creates a `Dense` layer with the specified input size, output size, and element-wise + /// activation function. The weight matrix is created with shape `[inputSize, outputSize]` and + /// is initialized using Glorot uniform initialization with the specified generator. The bias + /// vector is created with shape `[outputSize]` and is initialized with zeros. + /// + /// - Parameters: + /// - inputSize: The dimensionality of the input space. + /// - outputSize: The dimensionality of the output space. + /// - activation: The activation function to use. The default value is `identity(_:)`. + /// - generator: The random number generator for initialization. + /// + /// - Note: Use `init(inputSize:outputSize:activation:seed:)` for faster random initialization. + init( + inputSize: Int, + outputSize: Int, + activation: @escaping Activation = identity, + generator: inout G + ) { + self.init(weight: Tensor(glorotUniform: [inputSize, outputSize], + generator: &generator), + bias: Tensor(zeros: [outputSize]), + activation: activation) + } + + init(inputSize: Int, outputSize: Int, activation: @escaping Activation = identity) { + self.init(inputSize: inputSize, outputSize: outputSize, activation: activation, + generator: &PhiloxRandomNumberGenerator.global) + } +} + +public extension Dense { + /// Creates a `Dense` layer with the specified input size, output size, and element-wise + /// activation function. The weight matrix is created with shape `[inputSize, outputSize]` and + /// is initialized using Glorot uniform initialization with the specified seed. The bias vector + /// is created with shape `[outputSize]` and is initialized with zeros. + /// + /// - Parameters: + /// - inputSize: The dimensionality of the input space. + /// - outputSize: The dimensionality of the output space. + /// - activation: The activation function to use. The default value is `identity(_:)`. + /// - seed: The random seed for initialization. The default value is random. + init( + inputSize: Int, + outputSize: Int, + activation: @escaping Activation = identity, + seed: (Int64, Int64) = (Int64.random(in: Int64.min.. Output +} + +public extension Layer { + /// Returns the inference output obtained from applying the layer to the given input. + /// + /// - Parameter input: The input to the layer. + /// - Returns: The inference output. + @differentiable + func inferring(from input: Input) -> Output { + return withLearningPhase(LearningPhase.inference) { self(input) } + } + + // TODO(rxwei): Remove this custom VJP once differentiation supports currying. 
+ @differentiating(inferring(from:)) + @usableFromInline + internal func _vjpInferring(from input: Input) + -> (value: Output, pullback: (Output.CotangentVector) + -> (CotangentVector, Input.CotangentVector)) { + return withLearningPhase(LearningPhase.inference) { + let (output, pullback) = appliedForBackpropagation(to: input) + return (output, { v in pullback(v) }) + } + } + + typealias Backpropagator = (_ direction: Output.CotangentVector) + -> (layerGradient: CotangentVector, inputGradient: Input.CotangentVector) + + /// Returns the inference output and the backpropagation function obtained from applying the + /// layer to the given input. + /// + /// - Parameter input: The input to the layer. + /// - Returns: A tuple containing the output and the backpropagation function. The + /// backpropagation function (a.k.a. backpropagator) takes a direction vector and returns the + /// gradients at the layer and at the input, respectively. + func appliedForBackpropagation(to input: Input) + -> (output: Output, backpropagator: Backpropagator) { + let (out, pullback) = valueWithPullback(at: input) { layer, input in + return layer(input) + } + return (out, pullback) + } +} + +public extension Differentiable { + /// Returns the output computed by applying a sequence of layers to the previous layer's output, + /// except that the first layer's input is `self`. + /// + /// - Parameters: + /// - l1: The first layer. + /// - l2: The second layer. + /// - Returns: The final layer's output after sequential application. + @differentiable + func sequenced(through l1: L1, _ l2: L2) -> L2.Output + where L1.Input == Self, L1.Output == L2.Input { + let o1 = l1(self) + return l2(o1) + } + + /// Returns the output computed by applying a sequence of layers to the previous layer's output, + /// except that the first layer's input is `self`. + /// + /// - Parameters: + /// - l1: The first layer. + /// - l2: The second layer. + /// - l3: The third layer. + /// - Returns: The final layer's output after sequential application. + @differentiable + func sequenced(through l1: L1, _ l2: L2, _ l3: L3) -> L3.Output + where L1.Input == Self, L1.Output == L2.Input, L2.Output == L3.Input { + let o1 = l1(self) + let o2 = l2(o1) + return l3(o2) + } + + /// Returns the output computed by applying a sequence of layers to the previous layer's output, + /// except that the first layer's input is `self`. + /// + /// - Parameters: + /// - l1: The first layer. + /// - l2: The second layer. + /// - l3: The third layer. + /// - l4: The fourth layer. + /// - Returns: The final layer's output after sequential application. + @differentiable + func sequenced( + through l1: L1, _ l2: L2, _ l3: L3, _ l4: L4 + ) -> L4.Output + where L1.Input == Self, L1.Output == L2.Input, L2.Output == L3.Input, + L3.Output == L4.Input { + let o1 = l1(self) + let o2 = l2(o1) + let o3 = l3(o2) + return l4(o3) + } + + /// Returns the output computed by applying a sequence of layers to the previous layer's output, + /// except that the first layer's input is `self`. + /// + /// - Parameters: + /// - l1: The first layer. + /// - l2: The second layer. + /// - l3: The third layer. + /// - l4: The third layer. + /// - l5: The fifth layer. + /// - Returns: The final layer's output after sequential application. 
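+    ///
+    /// For example, assuming `conv`, `pool`, `flatten`, `dense1`, and `dense2` are layers whose
+    /// input and output types chain together (the names are purely illustrative):
+    ///
+    ///     let logits = input.sequenced(through: conv, pool, flatten, dense1, dense2)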
+ @differentiable + func sequenced( + through l1: L1, _ l2: L2, _ l3: L3, _ l4: L4, _ l5: L5 + ) -> L5.Output + where L1.Input == Self, L1.Output == L2.Input, L2.Output == L3.Input, L3.Output == L4.Input, + L4.Output == L5.Input { + let o1 = l1(self) + let o2 = l2(o1) + let o3 = l3(o2) + let o4 = l4(o3) + return l5(o4) + } + + /// Returns the output computed by applying a sequence of layers to the previous layer's output, + /// except that the first layer's input is `self`. + /// + /// - Parameters: + /// - l1: The first layer. + /// - l2: The second layer. + /// - l3: The third layer. + /// - l4: The third layer. + /// - l5: The fifth layer. + /// - l6: The sixth layer. + /// - Returns: The final layer's output after sequential application. + @differentiable + func sequenced( + through l1: L1, _ l2: L2, _ l3: L3, _ l4: L4, _ l5: L5, _ l6: L6 + ) -> L6.Output + where L1.Input == Self, L1.Output == L2.Input, L2.Output == L3.Input, L3.Output == L4.Input, + L4.Output == L5.Input, L5.Output == L6.Input { + let o1 = l1(self) + let o2 = l2(o1) + let o3 = l3(o2) + let o4 = l4(o3) + let o5 = l5(o4) + return l6(o5) + } +} + + +/// A mutable, shareable, owning reference to a tensor. +public final class Parameter { + public var value: Tensor + public init(_ value: Tensor) { + self.value = value + } +} diff --git a/Sources/DeepLearning/Layers/Normalization.swift b/Sources/DeepLearning/Layers/Normalization.swift new file mode 100644 index 000000000..4d70dd261 --- /dev/null +++ b/Sources/DeepLearning/Layers/Normalization.swift @@ -0,0 +1,202 @@ +// Copyright 2018 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#if !COMPILING_TENSORFLOW_MODULE +@_exported import TensorFlow +#endif + + +/// A batch normalization layer. +/// +/// Normalizes the activations of the previous layer at each batch, i.e. applies a transformation +/// that maintains the mean activation close to `0` and the activation standard deviation close to +/// `1`. +/// +/// Reference: [Batch Normalization: Accelerating Deep Network Training by Reducing Internal +/// Covariate Shift](https://arxiv.org/abs/1502.03167). +@_fixed_layout +public struct BatchNorm: Layer { + /// The feature dimension. + @noDerivative public let axis: Int + /// The momentum for the running mean and running variance. + @noDerivative public let momentum: Tensor + /// The offset value, also known as beta. + public var offset: Tensor + /// The scale value, also known as gamma. + public var scale: Tensor + /// The variance epsilon value. + @noDerivative public let epsilon: Tensor + /// The running mean. + @noDerivative public let runningMean: Parameter + /// The running variance. + @noDerivative public let runningVariance: Parameter + + /// Creates a batch normalization layer. + /// + /// - Parameters: + /// - axis: The axis that should not be normalized (typically the feature axis). + /// - momentum: The momentum for the moving average. + /// - offset: The offset to be added to the normalized tensor. 
+ /// - scale: The scale to multiply the normalized tensor by. + /// - epsilon: A small scalar added to the denominator to improve numerical stability. + /// - runningMean: The running mean. + /// - runningVariance: The running variance. + public init( + axis: Int, + momentum: Tensor, + offset: Tensor, + scale: Tensor, + epsilon: Tensor, + runningMean: Tensor, + runningVariance: Tensor + ) { + self.axis = axis + self.momentum = momentum + self.offset = offset + self.scale = scale + self.epsilon = epsilon + self.runningMean = Parameter(runningMean) + self.runningVariance = Parameter(runningVariance) + } + + @differentiable + private func applyingTraining(to input: Tensor) -> Tensor { + let positiveAxis = (input.rank + axis) % input.rank + var normalizedAxes = Array(0..) -> Tensor { + let inv = rsqrt(runningVariance.value + epsilon) * scale + return (input - runningMean.value) * inv + offset + } + + /// Returns the output obtained from applying the layer to the given input. + /// + /// - Parameters: + /// - input: The input to the layer. + /// - Returns: The output. + @differentiable(vjp: _vjpApplied(to:)) + public func call(_ input: Tensor) -> Tensor { + switch Context.local.learningPhase { + case .training: + return applyingTraining(to: input) + case .inference: + return applyingInference(to: input) + } + } + + @usableFromInline + func _vjpApplied(to input: Tensor) -> + (Tensor, (Tensor) -> + (BatchNorm.CotangentVector, Tensor)) { + switch Context.local.learningPhase { + case .training: + return valueWithPullback(at: input) { + $0.applyingTraining(to: $1) + } + case .inference: + return valueWithPullback(at: input) { + $0.applyingInference(to: $1) + } + } + } + + /// Creates a batch normalization layer. + /// + /// - Parameters: + /// - featureCount: The number of features. + /// - axis: The axis that should be normalized (typically the features axis). + /// - momentum: The momentum for the moving average. + /// - epsilon: A small scalar added to the denominator to improve numerical stability. + public init(featureCount: Int, + axis: Int = -1, + momentum: Tensor = Tensor(0.99), + epsilon: Tensor = Tensor(0.001)) { + self.axis = axis + self.momentum = momentum + self.scale = Tensor(ones: [featureCount]) + self.offset = Tensor(zeros: [featureCount]) + self.epsilon = epsilon + self.runningMean = Parameter(Tensor(0)) + self.runningVariance = Parameter(Tensor(1)) + } +} + + +/// A layer that applies layer normalization over a mini-batch of inputs. +/// +/// Reference: [Layer Normalization](https://arxiv.org/abs/1607.06450). +@_fixed_layout +public struct LayerNorm: Layer { + /// The offset value, also known as beta. + public var offset: Tensor + /// The scale value, also known as gamma. + public var scale: Tensor + /// The axis. + @noDerivative public let axis: Int + /// The variance epsilon value. + @noDerivative public let epsilon: Tensor + + /// Creates a layer normalization layer. + public init( + offset: Tensor, + scale: Tensor, + axis: Int, + epsilon: Tensor + ) { + self.offset = offset + self.scale = scale + self.axis = axis + self.epsilon = epsilon + } + + /// Creates a layer normalization layer. + /// + /// - Parameters: + /// - featureCount: The number of features. + /// - axis: The axis that should be normalized. + /// - epsilon: The small scalar added to variance. 
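+    ///
+    /// For example, a layer that normalizes the feature axis of a `[batch, 8]` input
+    /// (the feature count below is illustrative):
+    ///
+    ///     let norm = LayerNorm<Float>(featureCount: 8, axis: 1)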
+ public init(featureCount: Int, + axis: Int, + epsilon: Tensor = Tensor(0.001)) { + self.init( + offset: Tensor(zeros: [featureCount]), + scale: Tensor(ones: [featureCount]), + axis: axis, + epsilon: epsilon + ) + } + + /// Returns the output obtained from applying the layer to the given input. + /// + /// - Parameters: + /// - input: The input to the layer. + /// - Returns: The output. + @differentiable + public func call(_ input: Tensor) -> Tensor { + let mean = input.mean(alongAxes: axis) + let variance = input.variance(alongAxes: axis) + let inv = rsqrt(variance + epsilon) * scale + return (input - mean) * inv + offset + } +} diff --git a/Sources/DeepLearning/Layers/Pooling.Swift b/Sources/DeepLearning/Layers/Pooling.Swift new file mode 100644 index 000000000..1a8e922d2 --- /dev/null +++ b/Sources/DeepLearning/Layers/Pooling.Swift @@ -0,0 +1,240 @@ +// Copyright 2018 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#if !COMPILING_TENSORFLOW_MODULE +@_exported import TensorFlow +#endif + + +/// An average pooling layer for temporal data. +@_fixed_layout +public struct AvgPool1D: Layer { + /// The size of the sliding reduction window for pooling. + @noDerivative let poolSize: Int + /// The stride of the sliding window for temporal dimension. + @noDerivative let stride: Int + /// The padding algorithm for pooling. + @noDerivative let padding: Padding + + /// Creates an average pooling layer. + /// + /// - Parameters: + /// - poolSize: The size of the sliding reduction window for pooling. + /// - stride: The stride of the sliding window for temporal dimension. + /// - padding: The padding algorithm for pooling. + public init( + poolSize: Int, + stride: Int, + padding: Padding + ) { + self.poolSize = poolSize + self.stride = stride + self.padding = padding + } + + /// Returns the output obtained from applying the layer to the given input. + /// + /// - Parameters: + /// - input: The input to the layer. + /// - Returns: The output. + @differentiable + public func call(_ input: Tensor) -> Tensor { + return input.expandingShape(at: 1).averagePooled( + kernelSize: (1, 1, poolSize, 1), strides: (1, 1, stride, 1), padding: padding + ).squeezingShape(at: 1) + } +} + +/// An average pooling layer for spatial data. +@_fixed_layout +public struct AvgPool2D: Layer { + /// The size of the sliding reduction window for pooling. + @noDerivative let poolSize: (Int, Int, Int, Int) + /// The strides of the sliding window for each dimension of a 4-D input. + /// Strides in non-spatial dimensions must be `1`. + @noDerivative let strides: (Int, Int, Int, Int) + /// The padding algorithm for pooling. + @noDerivative let padding: Padding + + /// Creates a average pooling layer. + public init( + poolSize: (Int, Int, Int, Int), + strides: (Int, Int, Int, Int), + padding: Padding + ) { + self.poolSize = poolSize + self.strides = strides + self.padding = padding + } + + /// Creates a average pooling layer. 
+ /// + /// - Parameters: + /// - poolSize: Vertical and horizontal factors by which to downscale. + /// - strides: The strides. + /// - padding: The padding. + public init(poolSize: (Int, Int), strides: (Int, Int), padding: Padding = .valid) { + self.poolSize = (1, poolSize.0, poolSize.1, 1) + self.strides = (1, strides.0, strides.1, 1) + self.padding = padding + } + + /// Returns the output obtained from applying the layer to the given input. + /// + /// - Parameters: + /// - input: The input to the layer. + /// - Returns: The output. + @differentiable + public func call(_ input: Tensor) -> Tensor { + return input.averagePooled(kernelSize: poolSize, strides: strides, padding: padding) + } +} + + +/// A global average pooling layer for temporal data. +@_fixed_layout +public struct GlobalAvgPool1D: Layer { + /// Creates a global average pooling layer. + public init() {} + + /// Returns the output obtained from applying the layer to the given input. + /// + /// - Parameters: + /// - input: The input to the layer. + /// - Returns: The output. + @differentiable + public func call(_ input: Tensor) -> Tensor { + return input.mean(squeezingAxes: 1) + } +} + +/// A global average pooling layer for spatial data. +@_fixed_layout +public struct GlobalAvgPool2D: Layer { + /// Creates a global average pooling layer. + public init() {} + + /// Returns the output obtained from applying the layer to the given input. + /// + /// - Parameters: + /// - input: The input to the layer. + /// - Returns: The output. + @differentiable + public func call(_ input: Tensor) -> Tensor { + return input.mean(squeezingAxes: [1, 2]) + } +} + +/// A global average pooling layer for spatial and spatio-temporal data. +@_fixed_layout +public struct GlobalAvgPool3D: Layer { + /// Creates a global average pooling layer. + public init() {} + + /// Returns the output obtained from applying the layer to the given input. + /// + /// - Parameters: + /// - input: The input to the layer. + /// - Returns: The output. + @differentiable + public func call(_ input: Tensor) -> Tensor { + return input.mean(squeezingAxes: [1, 2, 3]) + } +} + + +/// A max pooling layer for temporal data. +@_fixed_layout +public struct MaxPool1D: Layer { + /// The size of the sliding reduction window for pooling. + @noDerivative let poolSize: Int + /// The stride of the sliding window for temporal dimension. + @noDerivative let stride: Int + /// The padding algorithm for pooling. + @noDerivative let padding: Padding + + /// Creates a max pooling layer. + /// + /// - Parameters: + /// - poolSize: The size of the sliding reduction window for pooling. + /// - stride: The stride of the sliding window for temporal dimension. + /// - padding: The padding algorithm for pooling. + public init( + poolSize: Int, + stride: Int, + padding: Padding + ) { + self.poolSize = poolSize + self.stride = stride + self.padding = padding + } + + /// Returns the output obtained from applying the layer to the given input. + /// + /// - Parameters: + /// - input: The input to the layer. + /// - Returns: The output. + @differentiable + public func call(_ input: Tensor) -> Tensor { + return input.expandingShape(at: 1).maxPooled( + kernelSize: (1, 1, poolSize, 1), strides: (1, 1, stride, 1), padding: padding + ).squeezingShape(at: 1) + } +} + +/// A max pooling layer for spatial data. +@_fixed_layout +public struct MaxPool2D: Layer { + /// The size of the sliding reduction window for pooling. 
+ @noDerivative let poolSize: (Int, Int, Int, Int) + /// The strides of the sliding window for each dimension of a 4-D input. + /// Strides in non-spatial dimensions must be `1`. + @noDerivative let strides: (Int, Int, Int, Int) + /// The padding algorithm for pooling. + @noDerivative let padding: Padding + + /// Creates a max pooling layer. + public init( + poolSize: (Int, Int, Int, Int), + strides: (Int, Int, Int, Int), + padding: Padding + ) { + self.poolSize = poolSize + self.strides = strides + self.padding = padding + } + + /// Creates a max pooling layer. + /// + /// - Parameters: + /// - poolSize: Vertical and horizontal factors by which to downscale. + /// - strides: The strides. + /// - padding: The padding. + public init(poolSize: (Int, Int), strides: (Int, Int), padding: Padding = .valid) { + self.poolSize = (1, poolSize.0, poolSize.1, 1) + self.strides = (1, strides.0, strides.1, 1) + self.padding = padding + } + + /// Returns the output obtained from applying the layer to the given input. + /// + /// - Parameters: + /// - input: The input to the layer. + /// - Returns: The output. + @differentiable + public func call(_ input: Tensor) -> Tensor { + return input.maxPooled( + kernelSize: poolSize, strides: strides, padding: padding) + } +} diff --git a/Sources/DeepLearning/Layers/Recurrent.Swift b/Sources/DeepLearning/Layers/Recurrent.Swift new file mode 100644 index 000000000..41e51cac2 --- /dev/null +++ b/Sources/DeepLearning/Layers/Recurrent.Swift @@ -0,0 +1,192 @@ +// Copyright 2018 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#if !COMPILING_TENSORFLOW_MODULE +@_exported import TensorFlow +#endif + + + +/// An input to a recurrent neural network. +public struct RNNCellInput: Differentiable { + /// The input at the current time step. + public var input: Input + /// The previous state. + public var state: State + + @differentiable + public init(input: Input, state: State) { + self.input = input + self.state = state + } +} + +/// An output to a recurrent neural network. +public struct RNNCellOutput: Differentiable { + /// The output at the current time step. + public var output: Output + /// The current state. + public var state: State + + @differentiable + public init(output: Output, state: State) { + self.output = output + self.state = state + } +} + +/// A recurrent neural network cell. +public protocol RNNCell: Layer where Input == RNNCellInput, + Output == RNNCellOutput { + /// The input at a time step. + associatedtype TimeStepInput: Differentiable + /// The output at a time step. + associatedtype TimeStepOutput: Differentiable + /// The state that may be preserved across time steps. + associatedtype State: Differentiable + /// The zero state. + var zeroState: State { get } +} + +public extension RNNCell { + /// Returns the new state obtained from applying the RNN cell to the input at the current time + /// step and the previous state. 
+ /// + /// - Parameters: + /// - timeStepInput: The input at the current time step. + /// - previousState: The previous state of the RNN cell. + /// - Returns: The output. + @differentiable + func call(input: TimeStepInput, state: State) -> RNNCellOutput { + return self(RNNCellInput(input: input, state: state)) + } +} + +/// A Simple RNN Cell. +public struct SimpleRNNCell: RNNCell { + public var weight: Tensor + public var bias: Tensor + + @noDerivative public var stateShape: TensorShape { + return TensorShape([1, weight.shape[1]]) + } + + public var zeroState: Tensor { + return Tensor(zeros: stateShape) + } + + public typealias State = Tensor + public typealias TimeStepInput = Tensor + public typealias TimeStepOutput = State + public typealias Input = RNNCellInput + public typealias Output = RNNCellOutput + + /// Creates a `SimpleRNNCell` with the specified input size and hidden state size. + /// + /// - Parameters: + /// - inputSize: The number of features in 2-D input tensors. + /// - hiddenSize: The number of features in 2-D hidden states. + public init(inputSize: Int, hiddenSize: Int) { + let concatenatedInputSize = inputSize + hiddenSize + self.weight = Tensor(glorotUniform: [concatenatedInputSize, hiddenSize]) + self.bias = Tensor(zeros: [hiddenSize]) + } + + /// Returns the output obtained from applying the layer to the given input. + /// + /// - Parameters: + /// - input: The input to the layer. + /// - context: The contextual information for the layer application, e.g. the current learning + /// phase. + /// - Returns: The hidden state. + @differentiable + public func call(_ input: Input) -> Output { + let concatenatedInput = input.input.concatenated(with: input.state, alongAxis: 1) + let newState = matmul(concatenatedInput, weight) + bias + return Output(output: newState, state: newState) + } +} + +/// An LSTM Cell. +public struct LSTMCell: RNNCell { + public var inputWeight, updateWeight, forgetWeight, outputWeight: Tensor + public var inputBias, updateBias, forgetBias, outputBias: Tensor + + @noDerivative public var stateShape: TensorShape { + return TensorShape([1, inputWeight.shape[1]]) + } + + public var zeroState: State { + return State(cell: Tensor(zeros: stateShape), hidden: Tensor(zeros: stateShape)) + } + + public typealias TimeStepInput = Tensor + public typealias TimeStepOutput = State + public typealias Input = RNNCellInput + public typealias Output = RNNCellOutput + + /// Creates a `LSTMCell` with the specified input size and hidden state size. + /// + /// - Parameters: + /// - inputSize: The number of features in 2-D input tensors. + /// - hiddenSize: The number of features in 2-D hidden states. 
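+    ///
+    /// A short usage sketch, where `step` is an illustrative `Tensor<Float>` of shape `[1, 4]`:
+    ///
+    ///     let cell = LSTMCell<Float>(inputSize: 4, hiddenSize: 16)
+    ///     let result = cell(input: step, state: cell.zeroState)
+    ///     // `result.state.hidden` and `result.state.cell` each have shape [1, 16].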
+ public init(inputSize: Int, hiddenSize: Int) { + let concatenatedInputSize = inputSize + hiddenSize + let gateWeightShape = TensorShape([concatenatedInputSize, hiddenSize]) + let gateBiasShape = TensorShape([hiddenSize]) + self.inputWeight = Tensor(glorotUniform: gateWeightShape) + self.inputBias = Tensor(zeros: gateBiasShape) + self.updateWeight = Tensor(glorotUniform: gateWeightShape) + self.updateBias = Tensor(zeros: gateBiasShape) + self.forgetWeight = Tensor(glorotUniform: gateWeightShape) + self.forgetBias = Tensor(ones: gateBiasShape) + self.outputWeight = Tensor(glorotUniform: gateWeightShape) + self.outputBias = Tensor(zeros: gateBiasShape) + } + + public struct State: Differentiable { + public var cell: Tensor + public var hidden: Tensor + + @differentiable + public init(cell: Tensor, hidden: Tensor) { + self.cell = cell + self.hidden = hidden + } + } + + /// Returns the output obtained from applying the layer to the given input. + /// + /// - Parameters: + /// - input: The input to the layer. + /// - context: The contextual information for the layer application, e.g. the current learning + /// phase. + /// - Returns: The hidden state. + @differentiable + public func call(_ input: Input) -> Output { + let gateInput = input.input.concatenated(with: input.state.hidden, alongAxis: 1) + + let inputGate = sigmoid(matmul(gateInput, inputWeight) + inputBias) + let updateGate = tanh(matmul(gateInput, updateWeight) + updateBias) + let forgetGate = sigmoid(matmul(gateInput, forgetWeight) + forgetBias) + let outputGate = sigmoid(matmul(gateInput, outputWeight) + outputBias) + + let newCellState = input.state.cell * forgetGate + inputGate * updateGate + let newHiddenState = tanh(newCellState) * outputGate + + let newState = State(cell: newCellState, hidden: newHiddenState) + + return Output(output: newState, state: newState) + } +} diff --git a/Sources/DeepLearning/Layers/Upsampling.swift b/Sources/DeepLearning/Layers/Upsampling.swift new file mode 100644 index 000000000..aaa1061f7 --- /dev/null +++ b/Sources/DeepLearning/Layers/Upsampling.swift @@ -0,0 +1,72 @@ +// Copyright 2018 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#if !COMPILING_TENSORFLOW_MODULE +@_exported import TensorFlow +#endif + + +/// An upsampling layer for 1-D inputs. +@_fixed_layout +public struct UpSampling1D: Layer { + @noDerivative public let size: Int + + /// Creates an upsampling layer. + /// + /// - Parameter size: The upsampling factor for timesteps. + public init(size: Int) { + self.size = size + } + + /// Returns the output obtained from applying the layer to the given input. + /// + /// - Parameters: + /// - input: The input to the layer. + /// - Returns: The output. 
+ @differentiable + public func call(_ input: Tensor) -> Tensor { + let shape = input.shape + let (batchSize, timesteps, channels) = (shape[0], shape[1], shape[2]) + let scaleOnes = Tensor(ones: [1, 1, size, 1]) + let upSampling = input.reshaped(to: [batchSize, timesteps, 1, channels]) * scaleOnes + return upSampling.reshaped(to: [batchSize, timesteps * size, channels]) + } +} + +/// An upsampling layer for 2-D inputs. +@_fixed_layout +public struct UpSampling2D: Layer { + @noDerivative public let size: Int + + /// Creates an upsampling layer. + /// + /// - Parameter size: The upsampling factor for rows and columns. + public init(size: Int) { + self.size = size + } + + /// Returns the output obtained from applying the layer to the given input. + /// + /// - Parameters: + /// - input: The input to the layer. + /// - Returns: The output. + @differentiable + public func call(_ input: Tensor) -> Tensor { + let shape = input.shape + let (batchSize, height, width, channels) = (shape[0], shape[1], shape[2], shape[3]) + let scaleOnes = Tensor(ones: [1, 1, size, 1, size, 1]) + let upSampling = input.reshaped(to: [batchSize, height, 1, width, 1, channels]) * scaleOnes + return upSampling.reshaped(to: [batchSize, height * size, width * size, channels]) + } +} From b6a5bf37b9a23e15329b22a1d9b2182bafb07f91 Mon Sep 17 00:00:00 2001 From: PAWAN SASANKA AMMANAMANCHI Date: Sun, 21 Apr 2019 06:02:53 +0530 Subject: [PATCH 2/8] Review Changes and merge error --- .../DeepLearning/Layers/Convolutional.swift | 1 - Sources/DeepLearning/Layers/Core.swift | 6 -- Sources/DeepLearning/Layers/Layer.swift | 3 - .../DeepLearning/Layers/Normalization.swift | 2 - Sources/DeepLearning/Layers/Pooling.Swift | 3 - Sources/DeepLearning/Layers/Recurrent.Swift | 92 ++++++++++++++++++- Sources/DeepLearning/Layers/Upsampling.swift | 1 - 7 files changed, 89 insertions(+), 19 deletions(-) diff --git a/Sources/DeepLearning/Layers/Convolutional.swift b/Sources/DeepLearning/Layers/Convolutional.swift index ad0f2bb59..674a2b4e7 100644 --- a/Sources/DeepLearning/Layers/Convolutional.swift +++ b/Sources/DeepLearning/Layers/Convolutional.swift @@ -16,7 +16,6 @@ @_exported import TensorFlow #endif - /// A 1-D convolution layer (e.g. temporal convolution over a time-series). /// /// This layer creates a convolution filter that is convolved with the layer input to produce a diff --git a/Sources/DeepLearning/Layers/Core.swift b/Sources/DeepLearning/Layers/Core.swift index 48a22aae1..12328b219 100644 --- a/Sources/DeepLearning/Layers/Core.swift +++ b/Sources/DeepLearning/Layers/Core.swift @@ -16,7 +16,6 @@ @_exported import TensorFlow #endif - public extension Tensor where Scalar: TensorFlowFloatingPoint { /// Computes dropout given a probability. @differentiable(wrt: self where Scalar: Differentiable) @@ -28,7 +27,6 @@ public extension Tensor where Scalar: TensorFlowFloatingPoint { } } - /// A dropout layer. /// /// Dropout consists in randomly setting a fraction of input units to `0` at each update during @@ -86,8 +84,6 @@ public struct Dropout: Layer { } } - - /// A flatten layer. /// /// A flatten layer flattens the input when applied without affecting the batch size. @@ -144,8 +140,6 @@ public struct Reshape: Layer { } } - - /// A densely-connected neural network layer. 
/// /// `Dense` implements the operation `activation(matmul(input, weight) + bias)`, where `weight` is diff --git a/Sources/DeepLearning/Layers/Layer.swift b/Sources/DeepLearning/Layers/Layer.swift index acdb69225..c3d50a57b 100644 --- a/Sources/DeepLearning/Layers/Layer.swift +++ b/Sources/DeepLearning/Layers/Layer.swift @@ -16,8 +16,6 @@ @_exported import TensorFlow #endif - - /// A neural network layer. /// /// Types that conform to `Layer` represent functions that map inputs to outputs. They may have an @@ -182,7 +180,6 @@ public extension Differentiable { } } - /// A mutable, shareable, owning reference to a tensor. public final class Parameter { public var value: Tensor diff --git a/Sources/DeepLearning/Layers/Normalization.swift b/Sources/DeepLearning/Layers/Normalization.swift index 4d70dd261..0827591d2 100644 --- a/Sources/DeepLearning/Layers/Normalization.swift +++ b/Sources/DeepLearning/Layers/Normalization.swift @@ -16,7 +16,6 @@ @_exported import TensorFlow #endif - /// A batch normalization layer. /// /// Normalizes the activations of the previous layer at each batch, i.e. applies a transformation @@ -142,7 +141,6 @@ public struct BatchNorm: Layer { } } - /// A layer that applies layer normalization over a mini-batch of inputs. /// /// Reference: [Layer Normalization](https://arxiv.org/abs/1607.06450). diff --git a/Sources/DeepLearning/Layers/Pooling.Swift b/Sources/DeepLearning/Layers/Pooling.Swift index 1a8e922d2..0c14c9073 100644 --- a/Sources/DeepLearning/Layers/Pooling.Swift +++ b/Sources/DeepLearning/Layers/Pooling.Swift @@ -16,7 +16,6 @@ @_exported import TensorFlow #endif - /// An average pooling layer for temporal data. @_fixed_layout public struct AvgPool1D: Layer { @@ -101,7 +100,6 @@ public struct AvgPool2D: Layer { } } - /// A global average pooling layer for temporal data. @_fixed_layout public struct GlobalAvgPool1D: Layer { @@ -153,7 +151,6 @@ public struct GlobalAvgPool3D: Layer { } } - /// A max pooling layer for temporal data. @_fixed_layout public struct MaxPool1D: Layer { diff --git a/Sources/DeepLearning/Layers/Recurrent.Swift b/Sources/DeepLearning/Layers/Recurrent.Swift index 41e51cac2..5e484f574 100644 --- a/Sources/DeepLearning/Layers/Recurrent.Swift +++ b/Sources/DeepLearning/Layers/Recurrent.Swift @@ -16,8 +16,6 @@ @_exported import TensorFlow #endif - - /// An input to a recurrent neural network. public struct RNNCellInput: Differentiable { /// The input at the current time step. 
@@ -113,7 +111,7 @@ public struct SimpleRNNCell: RNNCell { @differentiable public func call(_ input: Input) -> Output { let concatenatedInput = input.input.concatenated(with: input.state, alongAxis: 1) - let newState = matmul(concatenatedInput, weight) + bias + let newState = tanh(matmul(concatenatedInput, weight) + bias) return Output(output: newState, state: newState) } } @@ -190,3 +188,91 @@ public struct LSTMCell: RNNCell { return Output(output: newState, state: newState) } } + +public struct RNN: Layer { + public typealias Input = [Cell.TimeStepInput] + public typealias Output = [Cell.TimeStepOutput] + + public var cell: Cell + + public init(_ cell: @autoclosure () -> Cell) { + self.cell = cell() + } + + @differentiable(wrt: (self, input), vjp: _vjpCall(_:initialState:)) + public func call(_ input: [Cell.TimeStepInput], + initialState: Cell.State) -> [Cell.TimeStepOutput] { + var currentHiddenState = initialState + var timeStepOutputs: [Cell.TimeStepOutput] = [] + for timestep in input { + let output = cell(input: timestep, state: currentHiddenState) + currentHiddenState = output.state + timeStepOutputs.append(output.output) + } + return timeStepOutputs + } + + @usableFromInline + internal func _vjpCall( + _ inputs: [Cell.TimeStepInput], initialState: Cell.State + ) -> ([Cell.TimeStepOutput], + (Array.CotangentVector) + -> (CotangentVector, Array.CotangentVector)) { + let timeStepCount = inputs.count + var currentHiddenState = cell.zeroState + var timeStepOutputs: [Cell.TimeStepOutput] = [] + timeStepOutputs.reserveCapacity(timeStepCount) + var backpropagators: [Cell.Backpropagator] = [] + backpropagators.reserveCapacity(timeStepCount) + for timestep in inputs { + let (output, backpropagator) = + cell.appliedForBackpropagation(to: .init(input: timestep, + state: currentHiddenState)) + currentHiddenState = output.state + timeStepOutputs.append(output.output) + backpropagators.append(backpropagator) + } + return (timeStepOutputs, { 𝛁outputs in + precondition(𝛁outputs.base.count == timeStepCount, + "The number of output gradients must equal the number of time steps") + var 𝛁cell = Cell.CotangentVector.zero + var 𝛁state = Cell.State.CotangentVector.zero + var reversed𝛁inputs: [Cell.TimeStepInput.CotangentVector] = [] + reversed𝛁inputs.reserveCapacity(timeStepCount) + for (𝛁output, backpropagator) in zip(𝛁outputs.base, backpropagators).reversed() { + let (new𝛁cell, 𝛁input) = backpropagator(.init(output: 𝛁output, state: 𝛁state)) + 𝛁cell = new𝛁cell + 𝛁state = 𝛁input.state + reversed𝛁inputs.append(𝛁input.input) + } + return (.init(cell: 𝛁cell), .init(Array(reversed𝛁inputs.reversed()))) + }) + } + + @differentiable(wrt: (self, inputs)) + public func call(_ inputs: [Cell.TimeStepInput]) -> [Cell.TimeStepOutput] { + return self(inputs, initialState: cell.zeroState.withoutDerivative()) + } + + /* TODO: Uncomment once control flow and differentiation through force unwrapping is supported. + @differentiable(wrt: (self, inputs)) + public func lastOutput(from inputs: [Cell.TimeStepInput], + initialState: Cell.State) -> Cell.TimeStepOutput { + precondition(!inputs.isEmpty, "inputs cannot be empty") + return self(inputs, initialState: initialState).last! + } + + @differentiable(wrt: (self, inputs)) + public func lastOutput(from inputs: [Cell.TimeStepInput]) -> Cell.TimeStepOutput { + precondition(!inputs.isEmpty, "inputs cannot be empty") + return self(inputs, initialState: cell.zeroState).last! 
+ } + */ +} + +extension RNN: Equatable where Cell: Equatable {} +extension RNN: AdditiveArithmetic where Cell: AdditiveArithmetic {} +extension RNN: VectorNumeric where Cell: VectorNumeric {} + +public typealias SimpleRNN = RNN> +public typealias LSTM = RNN> diff --git a/Sources/DeepLearning/Layers/Upsampling.swift b/Sources/DeepLearning/Layers/Upsampling.swift index aaa1061f7..f476a58c0 100644 --- a/Sources/DeepLearning/Layers/Upsampling.swift +++ b/Sources/DeepLearning/Layers/Upsampling.swift @@ -16,7 +16,6 @@ @_exported import TensorFlow #endif - /// An upsampling layer for 1-D inputs. @_fixed_layout public struct UpSampling1D: Layer { From 63ad1ef64b7ef9eec9803cda25fa55a734a39d9d Mon Sep 17 00:00:00 2001 From: PAWAN SASANKA AMMANAMANCHI Date: Sun, 19 May 2019 15:47:58 +0530 Subject: [PATCH 3/8] updating directory to reflect PR changes --- .../DeepLearning/Layers/Convolutional.swift | 9 +- Sources/DeepLearning/Layers/Core.swift | 14 +- Sources/DeepLearning/Layers/Layer.swift | 12 +- .../DeepLearning/Layers/Normalization.swift | 8 +- Sources/DeepLearning/Layers/Pooling.Swift | 272 ++++++++++++------ Sources/DeepLearning/Layers/Recurrent.Swift | 65 +++-- Sources/DeepLearning/Layers/Upsampling.swift | 35 ++- 7 files changed, 275 insertions(+), 140 deletions(-) diff --git a/Sources/DeepLearning/Layers/Convolutional.swift b/Sources/DeepLearning/Layers/Convolutional.swift index 674a2b4e7..5861845ee 100644 --- a/Sources/DeepLearning/Layers/Convolutional.swift +++ b/Sources/DeepLearning/Layers/Convolutional.swift @@ -60,8 +60,7 @@ public struct Conv1D: Layer { /// Returns the output obtained from applying the layer to the given input. /// - /// - Parameters: - /// - input: The input to the layer `[batchCount, width, inputChannels]`. + /// - Parameter input: The input to the layer `[batchCount, width, inputChannels]`. /// - Returns: The output `[batchCount, newWidth, outputChannels]`. @differentiable public func call(_ input: Tensor) -> Tensor { @@ -179,8 +178,7 @@ public struct Conv2D: Layer { /// Returns the output obtained from applying the layer to the given input. /// - /// - Parameters: - /// - input: The input to the layer. + /// - Parameter input: The input to the layer. /// - Returns: The output. @differentiable public func call(_ input: Tensor) -> Tensor { @@ -298,8 +296,7 @@ public struct TransposedConv2D: Layer { /// Returns the output obtained from applying the layer to the given input. /// - /// - Parameters: - /// - input: The input to the layer. + /// - Parameter input: The input to the layer. /// - Returns: The output. @differentiable public func call(_ input: Tensor) -> Tensor { diff --git a/Sources/DeepLearning/Layers/Core.swift b/Sources/DeepLearning/Layers/Core.swift index 12328b219..0eea42c9e 100644 --- a/Sources/DeepLearning/Layers/Core.swift +++ b/Sources/DeepLearning/Layers/Core.swift @@ -54,8 +54,7 @@ public struct Dropout: Layer { /// Returns the output obtained from applying the layer to the given input. /// - /// - Parameters: - /// - input: The input to the layer. + /// - Parameter input: The input to the layer. /// - Returns: The output. 
@differentiable(vjp: _vjpApplied(to:)) public func call(_ input: Tensor) -> Tensor { @@ -70,7 +69,7 @@ public struct Dropout: Layer { @usableFromInline func _vjpApplied(to input: Tensor) -> (Tensor, (Tensor) -> - (Dropout.CotangentVector, Tensor)) { + (Dropout.TangentVector, Tensor)) { switch Context.local.learningPhase { case .training: return valueWithPullback(at: input) { @@ -94,8 +93,7 @@ public struct Flatten: Layer { /// Returns the output obtained from applying the layer to the given input. /// - /// - Parameters: - /// - input: The input to the layer. + /// - Parameter input: The input to the layer. /// - Returns: The output. @differentiable public func call(_ input: Tensor) -> Tensor { @@ -131,8 +129,7 @@ public struct Reshape: Layer { /// Returns the output obtained from applying the layer to the given input. /// - /// - Parameters: - /// - input: The input to the layer. + /// - Parameter input: The input to the layer. /// - Returns: The output. @differentiable public func call(_ input: Tensor) -> Tensor { @@ -167,8 +164,7 @@ public struct Dense: Layer { /// Returns the output obtained from applying the layer to the given input. /// - /// - Parameters: - /// - input: The input to the layer. + /// - Parameter input: The input to the layer. /// - Returns: The output. @differentiable public func call(_ input: Tensor) -> Tensor { diff --git a/Sources/DeepLearning/Layers/Layer.swift b/Sources/DeepLearning/Layers/Layer.swift index c3d50a57b..9c2011c60 100644 --- a/Sources/DeepLearning/Layers/Layer.swift +++ b/Sources/DeepLearning/Layers/Layer.swift @@ -31,8 +31,7 @@ public protocol Layer: Differentiable & KeyPathIterable /// Returns the output obtained from applying the layer to the given input. /// - /// - Parameters: - /// - input: The input to the layer. + /// - Parameter input: The input to the layer. /// - Returns: The output. @differentiable func call(_ input: Input) -> Output @@ -52,16 +51,16 @@ public extension Layer { @differentiating(inferring(from:)) @usableFromInline internal func _vjpInferring(from input: Input) - -> (value: Output, pullback: (Output.CotangentVector) - -> (CotangentVector, Input.CotangentVector)) { + -> (value: Output, pullback: (Output.TangentVector) + -> (TangentVector, Input.TangentVector)) { return withLearningPhase(LearningPhase.inference) { let (output, pullback) = appliedForBackpropagation(to: input) return (output, { v in pullback(v) }) } } - typealias Backpropagator = (_ direction: Output.CotangentVector) - -> (layerGradient: CotangentVector, inputGradient: Input.CotangentVector) + typealias Backpropagator = (_ direction: Output.TangentVector) + -> (layerGradient: TangentVector, inputGradient: Input.TangentVector) /// Returns the inference output and the backpropagation function obtained from applying the /// layer to the given input. @@ -180,6 +179,7 @@ public extension Differentiable { } } + /// A mutable, shareable, owning reference to a tensor. public final class Parameter { public var value: Tensor diff --git a/Sources/DeepLearning/Layers/Normalization.swift b/Sources/DeepLearning/Layers/Normalization.swift index 0827591d2..96e251204 100644 --- a/Sources/DeepLearning/Layers/Normalization.swift +++ b/Sources/DeepLearning/Layers/Normalization.swift @@ -91,8 +91,7 @@ public struct BatchNorm: Layer { /// Returns the output obtained from applying the layer to the given input. /// - /// - Parameters: - /// - input: The input to the layer. + /// - Parameter input: The input to the layer. /// - Returns: The output. 
@differentiable(vjp: _vjpApplied(to:)) public func call(_ input: Tensor) -> Tensor { @@ -107,7 +106,7 @@ public struct BatchNorm: Layer { @usableFromInline func _vjpApplied(to input: Tensor) -> (Tensor, (Tensor) -> - (BatchNorm.CotangentVector, Tensor)) { + (BatchNorm.TangentVector, Tensor)) { switch Context.local.learningPhase { case .training: return valueWithPullback(at: input) { @@ -187,8 +186,7 @@ public struct LayerNorm: Layer { /// Returns the output obtained from applying the layer to the given input. /// - /// - Parameters: - /// - input: The input to the layer. + /// - Parameter input: The input to the layer. /// - Returns: The output. @differentiable public func call(_ input: Tensor) -> Tensor { diff --git a/Sources/DeepLearning/Layers/Pooling.Swift b/Sources/DeepLearning/Layers/Pooling.Swift index 0c14c9073..934b124a0 100644 --- a/Sources/DeepLearning/Layers/Pooling.Swift +++ b/Sources/DeepLearning/Layers/Pooling.Swift @@ -16,9 +16,10 @@ @_exported import TensorFlow #endif -/// An average pooling layer for temporal data. + +/// A max pooling layer for temporal data. @_fixed_layout -public struct AvgPool1D: Layer { +public struct MaxPool1D: Layer { /// The size of the sliding reduction window for pooling. @noDerivative let poolSize: Int /// The stride of the sliding window for temporal dimension. @@ -26,7 +27,7 @@ public struct AvgPool1D: Layer { /// The padding algorithm for pooling. @noDerivative let padding: Padding - /// Creates an average pooling layer. + /// Creates a max pooling layer. /// /// - Parameters: /// - poolSize: The size of the sliding reduction window for pooling. @@ -44,20 +45,19 @@ public struct AvgPool1D: Layer { /// Returns the output obtained from applying the layer to the given input. /// - /// - Parameters: - /// - input: The input to the layer. + /// - Parameter input: The input to the layer. /// - Returns: The output. @differentiable public func call(_ input: Tensor) -> Tensor { - return input.expandingShape(at: 1).averagePooled( + return input.expandingShape(at: 1).maxPooled2D( kernelSize: (1, 1, poolSize, 1), strides: (1, 1, stride, 1), padding: padding ).squeezingShape(at: 1) } } -/// An average pooling layer for spatial data. +/// A max pooling layer for spatial data. @_fixed_layout -public struct AvgPool2D: Layer { +public struct MaxPool2D: Layer { /// The size of the sliding reduction window for pooling. @noDerivative let poolSize: (Int, Int, Int, Int) /// The strides of the sliding window for each dimension of a 4-D input. @@ -66,7 +66,7 @@ public struct AvgPool2D: Layer { /// The padding algorithm for pooling. @noDerivative let padding: Padding - /// Creates a average pooling layer. + /// Creates a max pooling layer. public init( poolSize: (Int, Int, Int, Int), strides: (Int, Int, Int, Int), @@ -77,83 +77,90 @@ public struct AvgPool2D: Layer { self.padding = padding } - /// Creates a average pooling layer. - /// - /// - Parameters: - /// - poolSize: Vertical and horizontal factors by which to downscale. - /// - strides: The strides. - /// - padding: The padding. - public init(poolSize: (Int, Int), strides: (Int, Int), padding: Padding = .valid) { - self.poolSize = (1, poolSize.0, poolSize.1, 1) - self.strides = (1, strides.0, strides.1, 1) - self.padding = padding - } - /// Returns the output obtained from applying the layer to the given input. /// - /// - Parameters: - /// - input: The input to the layer. + /// - Parameter input: The input to the layer. /// - Returns: The output. 
@differentiable public func call(_ input: Tensor) -> Tensor { - return input.averagePooled(kernelSize: poolSize, strides: strides, padding: padding) + return input.maxPooled2D( + kernelSize: poolSize, strides: strides, padding: padding) } } -/// A global average pooling layer for temporal data. -@_fixed_layout -public struct GlobalAvgPool1D: Layer { - /// Creates a global average pooling layer. - public init() {} - - /// Returns the output obtained from applying the layer to the given input. - /// - /// - Parameters: - /// - input: The input to the layer. - /// - Returns: The output. - @differentiable - public func call(_ input: Tensor) -> Tensor { - return input.mean(squeezingAxes: 1) - } +public extension MaxPool2D { + /// Creates a max pooling layer. + /// + /// - Parameters: + /// - poolSize: Vertical and horizontal factors by which to downscale. + /// - strides: The strides. + /// - padding: The padding. + init(poolSize: (Int, Int), strides: (Int, Int), padding: Padding = .valid) { + self.init(poolSize: (1, poolSize.0, poolSize.1, 1), + strides: (1, strides.0, strides.1, 1), + padding: padding) + } } -/// A global average pooling layer for spatial data. +/// A max pooling layer for spatial or spatio-temporal data. @_fixed_layout -public struct GlobalAvgPool2D: Layer { - /// Creates a global average pooling layer. - public init() {} +public struct MaxPool3D: Layer { + /// The size of the sliding reduction window for pooling. + @noDerivative let poolSize: (Int, Int, Int, Int, Int) + /// The strides of the sliding window for each dimension of a 5-D input. + /// Strides in non-spatial dimensions must be `1`. + @noDerivative let strides: (Int, Int, Int, Int, Int) + /// The padding algorithm for pooling. + @noDerivative let padding: Padding + + /// Creates a max pooling layer. + public init( + poolSize: (Int, Int, Int, Int, Int), + strides: (Int, Int, Int, Int, Int), + padding: Padding + ) { + self.poolSize = poolSize + self.strides = strides + self.padding = padding + } /// Returns the output obtained from applying the layer to the given input. /// - /// - Parameters: - /// - input: The input to the layer. + /// - Parameter input: The input to the layer. /// - Returns: The output. @differentiable public func call(_ input: Tensor) -> Tensor { - return input.mean(squeezingAxes: [1, 2]) + return input.maxPooled3D(kernelSize: poolSize, strides: strides, padding: padding) } } -/// A global average pooling layer for spatial and spatio-temporal data. -@_fixed_layout -public struct GlobalAvgPool3D: Layer { - /// Creates a global average pooling layer. - public init() {} +public extension MaxPool3D { + /// Creates a max pooling layer. + /// + /// - Parameters: + /// - poolSize: Vertical and horizontal factors by which to downscale. + /// - strides: The strides. + /// - padding: The padding. + init(poolSize: (Int, Int, Int), strides: (Int, Int, Int), padding: Padding = .valid) { + self.init(poolSize: (1, poolSize.0, poolSize.1, poolSize.2, 1), + strides: (1, strides.0, strides.1, strides.2, 1), + padding: padding) + } +} - /// Returns the output obtained from applying the layer to the given input. - /// - /// - Parameters: - /// - input: The input to the layer. - /// - Returns: The output. - @differentiable - public func call(_ input: Tensor) -> Tensor { - return input.mean(squeezingAxes: [1, 2, 3]) - } +public extension MaxPool3D { + /// Creates a max pooling layer with the specified pooling window size and stride. All + /// pooling sizes and strides are the same. 
+ init(poolSize: Int, stride: Int, padding: Padding = .valid) { + self.init(poolSize: (poolSize, poolSize, poolSize), + strides: (stride, stride, stride), + padding: padding) + } } -/// A max pooling layer for temporal data. +/// An average pooling layer for temporal data. @_fixed_layout -public struct MaxPool1D: Layer { +public struct AvgPool1D: Layer { /// The size of the sliding reduction window for pooling. @noDerivative let poolSize: Int /// The stride of the sliding window for temporal dimension. @@ -161,7 +168,7 @@ public struct MaxPool1D: Layer { /// The padding algorithm for pooling. @noDerivative let padding: Padding - /// Creates a max pooling layer. + /// Creates an average pooling layer. /// /// - Parameters: /// - poolSize: The size of the sliding reduction window for pooling. @@ -179,20 +186,19 @@ public struct MaxPool1D: Layer { /// Returns the output obtained from applying the layer to the given input. /// - /// - Parameters: - /// - input: The input to the layer. + /// - Parameter input: The input to the layer. /// - Returns: The output. @differentiable public func call(_ input: Tensor) -> Tensor { - return input.expandingShape(at: 1).maxPooled( + return input.expandingShape(at: 1).averagePooled2D( kernelSize: (1, 1, poolSize, 1), strides: (1, 1, stride, 1), padding: padding ).squeezingShape(at: 1) } } -/// A max pooling layer for spatial data. +/// An average pooling layer for spatial data. @_fixed_layout -public struct MaxPool2D: Layer { +public struct AvgPool2D: Layer { /// The size of the sliding reduction window for pooling. @noDerivative let poolSize: (Int, Int, Int, Int) /// The strides of the sliding window for each dimension of a 4-D input. @@ -201,7 +207,7 @@ public struct MaxPool2D: Layer { /// The padding algorithm for pooling. @noDerivative let padding: Padding - /// Creates a max pooling layer. + /// Creates an average pooling layer. public init( poolSize: (Int, Int, Int, Int), strides: (Int, Int, Int, Int), @@ -212,26 +218,130 @@ public struct MaxPool2D: Layer { self.padding = padding } - /// Creates a max pooling layer. + /// Returns the output obtained from applying the layer to the given input. /// - /// - Parameters: - /// - poolSize: Vertical and horizontal factors by which to downscale. - /// - strides: The strides. - /// - padding: The padding. - public init(poolSize: (Int, Int), strides: (Int, Int), padding: Padding = .valid) { - self.poolSize = (1, poolSize.0, poolSize.1, 1) - self.strides = (1, strides.0, strides.1, 1) + /// - Parameter input: The input to the layer. + /// - Returns: The output. + @differentiable + public func call(_ input: Tensor) -> Tensor { + return input.averagePooled2D(kernelSize: poolSize, strides: strides, padding: padding) + } +} + +public extension AvgPool2D { + /// Creates an average pooling layer. + /// + /// - Parameters: + /// - poolSize: Vertical and horizontal factors by which to downscale. + /// - strides: The strides. + /// - padding: The padding. + init(poolSize: (Int, Int), strides: (Int, Int), padding: Padding = .valid) { + self.init(poolSize: (1, poolSize.0, poolSize.1, 1), + strides: (1, strides.0, strides.1, 1), + padding: padding) + } +} + +/// An average pooling layer for spatial or spatio-temporal data. +@_fixed_layout +public struct AvgPool3D: Layer { + /// The size of the sliding reduction window for pooling. + @noDerivative let poolSize: (Int, Int, Int, Int, Int) + /// The strides of the sliding window for each dimension of a 5-D input. + /// Strides in non-spatial dimensions must be `1`. 
+ @noDerivative let strides: (Int, Int, Int, Int, Int) + /// The padding algorithm for pooling. + @noDerivative let padding: Padding + + /// Creates an average pooling layer. + public init( + poolSize: (Int, Int, Int, Int, Int), + strides: (Int, Int, Int, Int, Int), + padding: Padding + ) { + self.poolSize = poolSize + self.strides = strides self.padding = padding } /// Returns the output obtained from applying the layer to the given input. /// - /// - Parameters: - /// - input: The input to the layer. + /// - Parameter input: The input to the layer. /// - Returns: The output. @differentiable public func call(_ input: Tensor) -> Tensor { - return input.maxPooled( - kernelSize: poolSize, strides: strides, padding: padding) + return input.averagePooled3D(kernelSize: poolSize, strides: strides, padding: padding) + } +} + +public extension AvgPool3D { + /// Creates an average pooling layer. + /// + /// - Parameters: + /// - poolSize: Vertical and horizontal factors by which to downscale. + /// - strides: The strides. + /// - padding: The padding. + init(poolSize: (Int, Int, Int), strides: (Int, Int, Int), padding: Padding = .valid) { + self.init(poolSize: (1, poolSize.0, poolSize.1, poolSize.2, 1), + strides: (1, strides.0, strides.1, strides.2, 1), + padding: padding) + } +} + +public extension AvgPool3D { + /// Creates an average pooling layer with the specified pooling window size and stride. All + /// pooling sizes and strides are the same. + init(poolSize: Int, strides: Int, padding: Padding = .valid) { + self.init(poolSize: (poolSize, poolSize, poolSize), + strides: (strides, strides, strides), + padding: padding) + } +} + +/// A global average pooling layer for temporal data. +@_fixed_layout +public struct GlobalAvgPool1D: Layer { + /// Creates a global average pooling layer. + public init() {} + + /// Returns the output obtained from applying the layer to the given input. + /// + /// - Parameter input: The input to the layer. + /// - Returns: The output. + @differentiable + public func call(_ input: Tensor) -> Tensor { + return input.mean(squeezingAxes: 1) + } +} + +/// A global average pooling layer for spatial data. +@_fixed_layout +public struct GlobalAvgPool2D: Layer { + /// Creates a global average pooling layer. + public init() {} + + /// Returns the output obtained from applying the layer to the given input. + /// + /// - Parameter input: The input to the layer. + /// - Returns: The output. + @differentiable + public func call(_ input: Tensor) -> Tensor { + return input.mean(squeezingAxes: [1, 2]) + } +} + +/// A global average pooling layer for spatial and spatio-temporal data. +@_fixed_layout +public struct GlobalAvgPool3D: Layer { + /// Creates a global average pooling layer. + public init() {} + + /// Returns the output obtained from applying the layer to the given input. + /// + /// - Parameter input: The input to the layer. + /// - Returns: The output. + @differentiable + public func call(_ input: Tensor) -> Tensor { + return input.mean(squeezingAxes: [1, 2, 3]) } } diff --git a/Sources/DeepLearning/Layers/Recurrent.Swift b/Sources/DeepLearning/Layers/Recurrent.Swift index 5e484f574..d5fef0399 100644 --- a/Sources/DeepLearning/Layers/Recurrent.Swift +++ b/Sources/DeepLearning/Layers/Recurrent.Swift @@ -71,8 +71,8 @@ public extension RNNCell { } } -/// A Simple RNN Cell. -public struct SimpleRNNCell: RNNCell { +/// A simple RNN cell. 
+public struct SimpleRNNCell<Scalar: TensorFlowFloatingPoint>: RNNCell, VectorNumeric {
   public var weight: Tensor<Scalar>
   public var bias: Tensor<Scalar>
@@ -80,11 +80,19 @@
     return TensorShape([1, weight.shape[1]])
   }
 
-  public var zeroState: Tensor<Scalar> {
-    return Tensor(zeros: stateShape)
+  public var zeroState: State {
+    return State(Tensor(zeros: stateShape))
+  }
+
+  // TODO(TF-507): Revert to `typealias State = Tensor<Scalar>` after
+  // SR-10697 is fixed.
+  public struct State: Equatable, Differentiable, VectorNumeric, KeyPathIterable {
+    public let value: Tensor<Scalar>
+    public init(_ value: Tensor<Scalar>) {
+      self.value = value
+    }
   }
 
-  public typealias State = Tensor<Scalar>
   public typealias TimeStepInput = Tensor<Scalar>
   public typealias TimeStepOutput = State
   public typealias Input = RNNCellInput<TimeStepInput, State>
@@ -95,29 +103,29 @@
   /// - Parameters:
   ///   - inputSize: The number of features in 2-D input tensors.
   ///   - hiddenSize: The number of features in 2-D hidden states.
-  public init(inputSize: Int, hiddenSize: Int) {
+  ///   - seed: The random seed for initialization. The default value is random.
+  public init(inputSize: Int, hiddenSize: Int,
+              seed: (Int64, Int64) = (Int64.random(in: Int64.min..
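As a quick illustration of how the reorganized recurrent and pooling layers fit together — a minimal sketch only, not part of the patch, assuming this branch's `DeepLearning` module (which re-exports `TensorFlow`) builds and supports the callable `call(_:)` sugar used above; shapes and sizes are made up:

import DeepLearning

// Build a recurrent layer from the generic `RNN` wrapper and a `SimpleRNNCell`.
// The cell's default zero state has shape [1, hiddenSize], so batch size 1 is
// used here.
let cell = SimpleRNNCell<Float>(inputSize: 4, hiddenSize: 8)
let rnn = RNN(cell)

// A toy sequence of 10 time steps, each of shape [batch = 1, features = 4].
let sequence: [Tensor<Float>] = (0..<10).map { _ in Tensor<Float>(zeros: [1, 4]) }

// Applying the RNN threads the hidden state through all time steps and returns
// one output per step; on this branch each output is a `SimpleRNNCell<Float>.State`.
let outputs = rnn(sequence)
print(outputs.count)  // 10

// The pooling layers keep convenience initializers after the split, e.g.:
let pool = MaxPool2D<Float>(poolSize: (2, 2), strides: (2, 2))
let pooled = pool(Tensor<Float>(zeros: [1, 28, 28, 3]))
print(pooled.shape)  // [1, 14, 14, 3]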