From 897c26bce6d0d16e794639d16e5554102b4669fc Mon Sep 17 00:00:00 2001
From: Richard Wei
Date: Sun, 14 Apr 2019 04:43:33 +0100
Subject: [PATCH 1/2] [Layer] Remove the 'context' argument to
 `Layer.applied(to:in:)`.

---
 Sources/DeepLearning/Context.swift            | 148 ++++++++++
 Sources/DeepLearning/Layer.swift              | 255 ++++++------------
 Tests/DeepLearningTests/SequentialTests.swift |  11 +-
 .../DeepLearningTests/TrivialModelTests.swift |  10 +-
 4 files changed, 242 insertions(+), 182 deletions(-)
 create mode 100644 Sources/DeepLearning/Context.swift

diff --git a/Sources/DeepLearning/Context.swift b/Sources/DeepLearning/Context.swift
new file mode 100644
index 000000000..266ab2902
--- /dev/null
+++ b/Sources/DeepLearning/Context.swift
@@ -0,0 +1,148 @@
+// Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#if !COMPILING_TENSORFLOW_MODULE
+import TensorFlow
+#endif
+
+#if os(macOS) || os(iOS) || os(tvOS) || os(watchOS)
+import Darwin
+#else
+import Glibc
+#endif
+
+/// A value that indicates the phase of using a machine learning model.
+public enum LearningPhase {
+    case training
+    case inference
+}
+
+/// A context that stores thread-local contextual information used by deep learning APIs such as
+/// layers.
+///
+/// Use `Context.local` to retrieve the current thread-local context.
+///
+/// Examples:
+///
+/// * Set the current learning phase to training so that layers like `BatchNorm` will
+///   compute mean and variance when applied to inputs.
+///
+///   ```swift
+///   Context.local.learningPhase = .training
+///   ```
+/// * Set the current learning phase to inference so that layers like `Dropout` will not drop out
+///   units when applied to inputs.
+///
+///   ```swift
+///   Context.local.learningPhase = .inference
+///   ```
+public struct Context {
+    /// The learning phase.
+    public var learningPhase: LearningPhase = .inference
+
+    /// Creates a context with default properties.
+    public init() {}
+
+    /// The current thread-local context.
+    ///
+    /// - Note: Accessing this property is thread-safe.
+    public static var local: Context {
+        _read { yield ContextManager.local.currentContext }
+        _modify { yield &ContextManager.local.currentContext }
+    }
+}
+
+/// Calls the given closure within a temporary context. The given context is set before the
+/// closure gets called and restored after the closure returns.
+///
+/// - Parameters:
+///   - context: A context that will be set before the closure gets called and restored after the
+///     closure returns.
+///   - body: A nullary closure. If the closure has a return value, that value is also used as the
+///     return value of the `withContext(_:_:)` function.
+/// - Returns: The return value, if any, of the `body` closure.
+public func withContext<R>(_ context: Context, _ body: () throws -> R) rethrows -> R {
+    ContextManager.local.push(context)
+    defer { ContextManager.local.popContext() }
+    return try body()
+}
+
+/// Calls the given closure within a context that has everything identical to the current context
+/// except for the given learning phase.
+///
+/// - Parameters:
+///   - learningPhase: A learning phase that will be set before the closure gets called and
+///     restored after the closure returns.
+///   - body: A nullary closure. If the closure has a return value, that value is also used as the
+///     return value of the `withLearningPhase(_:_:)` function.
+/// - Returns: The return value, if any, of the `body` closure.
+public func withLearningPhase<R>(_ learningPhase: LearningPhase,
+                                 _ body: () throws -> R) rethrows -> R {
+    var context = ContextManager.local.currentContext
+    context.learningPhase = learningPhase
+    return try withContext(context, body)
+}
+
+/// A manager that maintains and provides safe access to thread-local `Context` values.
+private final class ContextManager {
+    var contextStack: [Context] = [Context()]
+
+    /// The pthread data key for the thread-local `ContextManager` singleton.
+    static let key: pthread_key_t = {
+        var key = pthread_key_t()
+        pthread_key_create(&key) { obj in
+#if !(os(macOS) || os(iOS) || os(watchOS) || os(tvOS))
+            let obj = obj!
+#endif
+            Unmanaged<ContextManager>.fromOpaque(obj).release()
+        }
+        return key
+    }()
+
+    /// The thread-local singleton.
+    static var local: ContextManager {
+        if let address = pthread_getspecific(key) {
+            return Unmanaged<ContextManager>.fromOpaque(address).takeUnretainedValue()
+        }
+        let context = ContextManager()
+        pthread_setspecific(key, Unmanaged.passRetained(context).toOpaque())
+        return context
+    }
+
+    /// Pushes the given context onto the context stack.
+    func push(_ context: Context) {
+        contextStack.append(context)
+    }
+
+    /// Pops a context off the stack.
+    ///
+    /// - Precondition: The context stack must contain more than one context.
+    func popContext() {
+        assert(contextStack.count > 1,
+               "Internal error: Only 1 context is available. Popping is not allowed.")
+        contextStack.removeLast()
+    }
+
+    /// The most recent context.
+    var currentContext: Context {
+        _read {
+            assert(!contextStack.isEmpty, "Internal error: No contexts exist.")
+            yield contextStack[contextStack.endIndex - 1]
+        }
+        _modify {
+            assert(!contextStack.isEmpty, "Internal error: No contexts exist.")
+            yield &contextStack[contextStack.endIndex - 1]
+        }
+    }
+}
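Taken together, the new context machinery is meant to be used roughly as sketched below. This is an illustration only, not part of the patch; it assumes the `Dropout` initializer and `Tensor(repeating:shape:)` initializer exercised by the tests added later in this PR:

```swift
import DeepLearning

let dropout = Dropout<Float>(probability: 0.5)
let x = Tensor<Float>(repeating: 1.0, shape: [5, 5])

// The default learning phase is `.inference`, so dropout passes inputs through.
let y0 = dropout.applied(to: x)             // y0 == x

// Temporarily switch the thread-local phase to `.training`; the previous
// phase is restored when the closure returns.
let y1 = withLearningPhase(.training) {
    dropout.applied(to: x)                  // units are dropped here
}

// Alternatively, mutate the thread-local context directly.
Context.local.learningPhase = .training
let y2 = dropout.applied(to: x)             // units are dropped here too
```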
diff --git a/Sources/DeepLearning/Layer.swift b/Sources/DeepLearning/Layer.swift
index 3b4746a6b..43894153b 100644
--- a/Sources/DeepLearning/Layer.swift
+++ b/Sources/DeepLearning/Layer.swift
@@ -16,38 +16,12 @@
 @_exported import TensorFlow
 #endif
 
-/// A value that indicates either a training phase or an inference phase for a layer.
-public enum LearningPhase {
-    case training
-    case inference
-}
-
-/// A context that stores contextual information used for the application of layers.
-open class Context {
-    /// The current learning phase.
-    public var learningPhase: LearningPhase
-
-    /// Creates a context.
-    ///
-    /// - Parameter learningPhase: The current learning phase.
-    public required init(learningPhase: LearningPhase) {
-        self.learningPhase = learningPhase
-    }
-
-    /// Creates a context by copying all information from an existing context.
-    ///
-    /// - Parameter context: The existing context to copy from.
-    public required init(_ other: Context) {
-        self.learningPhase = other.learningPhase
-    }
-}
-
 /// A neural network layer.
 ///
 /// Types that conform to `Layer` represent functions that map inputs to outputs. They may have an
 /// internal state represented by parameters, such as weight tensors.
 ///
-/// `Layer` instances define a differentiable `applied(to:in:)` method for mapping inputs to
+/// `Layer` instances define a differentiable `applied(to:)` method for mapping inputs to
 /// outputs.
 public protocol Layer: Differentiable & KeyPathIterable
     where AllDifferentiableVariables: KeyPathIterable {
@@ -60,28 +34,33 @@ public protocol Layer: Differentiable & KeyPathIterable
     ///
     /// - Parameters:
     ///   - input: The input to the layer.
-    ///   - context: The contextual information for the layer application, e.g. the current learning
-    ///     phase.
     /// - Returns: The output.
     @differentiable
-    func applied(to input: Input, in context: Context) -> Output
+    func applied(to input: Input) -> Output
 }
 
 public extension Layer {
-    @available(*, deprecated,
-               message: "Switch to 'applied(to:in:)' for training, or 'inferring(from:)' for inference")
-    func applied(to input: Input) -> Output {
-        return inferring(from: input)
-    }
-
     /// Returns the inference output obtained from applying the layer to the given input.
     ///
     /// - Parameter input: The input to the layer.
     /// - Returns: The inference output.
     @differentiable
     func inferring(from input: Input) -> Output {
-        let context = Context(learningPhase: .inference)
-        return applied(to: input, in: context)
+        return withLearningPhase(.inference) {
+            applied(to: input)
+        }
+    }
+
+    // TODO(rxwei): Remove this custom VJP once differentiation supports currying.
+    @differentiating(inferring(from:))
+    @usableFromInline
+    internal func _vjpInferring(from input: Input)
+        -> (value: Output, pullback: (Output.CotangentVector)
+            -> (CotangentVector, Input.CotangentVector)) {
+        return withLearningPhase(.inference) {
+            let (output, pullback) = appliedForBackpropagation(to: input)
+            return (output, { v in pullback(v) })
+        }
     }
 
     typealias Backpropagator = (_ direction: Output.CotangentVector)
@@ -94,10 +73,10 @@ public extension Layer {
     /// - Returns: A tuple containing the output and the backpropagation function. The
     ///   backpropagation function (a.k.a. backpropagator) takes a direction vector and returns the
    ///   gradients at the layer and at the input, respectively.
-    func appliedForBackpropagation(to input: Input, in context: Context)
+    func appliedForBackpropagation(to input: Input)
         -> (output: Output, backpropagator: Backpropagator) {
         let (out, pullback) = valueWithPullback(at: input) { layer, input in
-            return layer.applied(to: input, in: context)
+            return layer.applied(to: input)
         }
         return (out, pullback)
     }
@@ -108,49 +87,36 @@ public extension Differentiable {
     /// except that the first layer's input is `self`.
     ///
     /// - Parameters:
-    ///   - context: The context that stores contextual information used for the application of
-    ///     layers.
     ///   - l1: The first layer.
     ///   - l2: The second layer.
     /// - Returns: The final layer's output after sequential application.
     @differentiable
-    func sequenced<L1: Layer, L2: Layer>(
-        in context: Context, through l1: L1, _ l2: L2)
-        -> L2.Output
-        where L1.Input == Self,
-              L1.Output == L2.Input {
-        let o1 = l1.applied(to: self, in: context)
-        return l2.applied(to: o1, in: context)
+    func sequenced<L1: Layer, L2: Layer>(through l1: L1, _ l2: L2) -> L2.Output
+        where L1.Input == Self, L1.Output == L2.Input {
+        let o1 = l1.applied(to: self)
+        return l2.applied(to: o1)
     }
 
     /// Returns the output computed by applying a sequence of layers to the previous layer's output,
     /// except that the first layer's input is `self`.
     ///
     /// - Parameters:
-    ///   - context: The context that stores contextual information used for the application of
-    ///     layers.
     ///   - l1: The first layer.
     ///   - l2: The second layer.
     ///   - l3: The third layer.
     /// - Returns: The final layer's output after sequential application.
     @differentiable
-    func sequenced<L1: Layer, L2: Layer, L3: Layer>(
-        in context: Context, through l1: L1, _ l2: L2, _ l3: L3)
-        -> L3.Output
-        where L1.Input == Self,
-              L1.Output == L2.Input,
-              L2.Output == L3.Input {
-        let o1 = l1.applied(to: self, in: context)
-        let o2 = l2.applied(to: o1, in: context)
-        return l3.applied(to: o2, in: context)
+    func sequenced<L1: Layer, L2: Layer, L3: Layer>(through l1: L1, _ l2: L2, _ l3: L3) -> L3.Output
+        where L1.Input == Self, L1.Output == L2.Input, L2.Output == L3.Input {
+        let o1 = l1.applied(to: self)
+        let o2 = l2.applied(to: o1)
+        return l3.applied(to: o2)
     }
 
     /// Returns the output computed by applying a sequence of layers to the previous layer's output,
     /// except that the first layer's input is `self`.
     ///
     /// - Parameters:
-    ///   - context: The context that stores contextual information used for the application of
-    ///     layers.
     ///   - l1: The first layer.
     ///   - l2: The second layer.
     ///   - l3: The third layer.
@@ -158,24 +124,20 @@ public extension Differentiable {
     /// - Returns: The final layer's output after sequential application.
     @differentiable
     func sequenced<L1: Layer, L2: Layer, L3: Layer, L4: Layer>(
-        in context: Context, through l1: L1, _ l2: L2, _ l3: L3, _ l4: L4)
-        -> L4.Output
-        where L1.Input == Self,
-              L1.Output == L2.Input,
-              L2.Output == L3.Input,
-              L3.Output == L4.Input {
-        let o1 = l1.applied(to: self, in: context)
-        let o2 = l2.applied(to: o1, in: context)
-        let o3 = l3.applied(to: o2, in: context)
-        return l4.applied(to: o3, in: context)
+        through l1: L1, _ l2: L2, _ l3: L3, _ l4: L4
+    ) -> L4.Output
+        where L1.Input == Self, L1.Output == L2.Input, L2.Output == L3.Input,
+              L3.Output == L4.Input {
+        let o1 = l1.applied(to: self)
+        let o2 = l2.applied(to: o1)
+        let o3 = l3.applied(to: o2)
+        return l4.applied(to: o3)
     }
 
     /// Returns the output computed by applying a sequence of layers to the previous layer's output,
     /// except that the first layer's input is `self`.
     ///
     /// - Parameters:
-    ///   - context: The context that stores contextual information used for the application of
-    ///     layers.
     ///   - l1: The first layer.
     ///   - l2: The second layer.
     ///   - l3: The third layer.
@@ -184,26 +146,21 @@ public extension Differentiable {
     /// - Returns: The final layer's output after sequential application.
     @differentiable
     func sequenced<L1: Layer, L2: Layer, L3: Layer, L4: Layer, L5: Layer>(
-        in context: Context, through l1: L1, _ l2: L2, _ l3: L3, _ l4: L4, _ l5: L5)
-        -> L5.Output
-        where L1.Input == Self,
-              L1.Output == L2.Input,
-              L2.Output == L3.Input,
-              L3.Output == L4.Input,
-              L4.Output == L5.Input {
-        let o1 = l1.applied(to: self, in: context)
-        let o2 = l2.applied(to: o1, in: context)
-        let o3 = l3.applied(to: o2, in: context)
-        let o4 = l4.applied(to: o3, in: context)
-        return l5.applied(to: o4, in: context)
+        through l1: L1, _ l2: L2, _ l3: L3, _ l4: L4, _ l5: L5
+    ) -> L5.Output
+        where L1.Input == Self, L1.Output == L2.Input, L2.Output == L3.Input, L3.Output == L4.Input,
+              L4.Output == L5.Input {
+        let o1 = l1.applied(to: self)
+        let o2 = l2.applied(to: o1)
+        let o3 = l3.applied(to: o2)
+        let o4 = l4.applied(to: o3)
+        return l5.applied(to: o4)
     }
 
     /// Returns the output computed by applying a sequence of layers to the previous layer's output,
     /// except that the first layer's input is `self`.
     ///
     /// - Parameters:
-    ///   - context: The context that stores contextual information used for the application of
-    ///     layers.
     ///   - l1: The first layer.
     ///   - l2: The second layer.
     ///   - l3: The third layer.
@@ -213,20 +170,16 @@ public extension Differentiable {
     /// - Returns: The final layer's output after sequential application.
     @differentiable
     func sequenced<L1: Layer, L2: Layer, L3: Layer, L4: Layer, L5: Layer, L6: Layer>(
-        in context: Context, through l1: L1, _ l2: L2, _ l3: L3, _ l4: L4, _ l5: L5, _ l6: L6)
-        -> L6.Output
-        where L1.Input == Self,
-              L1.Output == L2.Input,
-              L2.Output == L3.Input,
-              L3.Output == L4.Input,
-              L4.Output == L5.Input,
-              L5.Output == L6.Input {
-        let o1 = l1.applied(to: self, in: context)
-        let o2 = l2.applied(to: o1, in: context)
-        let o3 = l3.applied(to: o2, in: context)
-        let o4 = l4.applied(to: o3, in: context)
-        let o5 = l5.applied(to: o4, in: context)
-        return l6.applied(to: o5, in: context)
+        through l1: L1, _ l2: L2, _ l3: L3, _ l4: L4, _ l5: L5, _ l6: L6
+    ) -> L6.Output
+        where L1.Input == Self, L1.Output == L2.Input, L2.Output == L3.Input, L3.Output == L4.Input,
+              L4.Output == L5.Input, L5.Output == L6.Input {
+        let o1 = l1.applied(to: self)
+        let o2 = l2.applied(to: o1)
+        let o3 = l3.applied(to: o2)
+        let o4 = l4.applied(to: o3)
+        let o5 = l5.applied(to: o4)
+        return l6.applied(to: o5)
     }
 }
 
@@ -268,11 +221,9 @@ public struct Dense<Scalar: TensorFlowFloatingPoint>: Layer {
     ///
     /// - Parameters:
     ///   - input: The input to the layer.
-    ///   - context: The contextual information for the layer application, e.g. the current learning
-    ///     phase.
     /// - Returns: The output.
     @differentiable
-    public func applied(to input: Tensor<Scalar>, in _: Context) -> Tensor<Scalar> {
+    public func applied(to input: Tensor<Scalar>) -> Tensor<Scalar> {
         return activation(matmul(input, weight) + bias)
     }
 }
@@ -379,11 +330,9 @@ public struct Conv1D<Scalar: TensorFlowFloatingPoint>: Layer {
     ///
     /// - Parameters:
     ///   - input: The input to the layer `[batchCount, width, inputChannels]`.
-    ///   - context: The contextual information for the layer application, e.g. the current learning
-    ///     phase.
     /// - Returns: The output `[batchCount, newWidth, outputChannels]`.
     @differentiable
-    public func applied(to input: Tensor<Scalar>, in _: Context) -> Tensor<Scalar> {
+    public func applied(to input: Tensor<Scalar>) -> Tensor<Scalar> {
         let conv2D = input.expandingShape(at: 1).convolved2D(
             withFilter: filter.expandingShape(at: 0), strides: (1, 1, stride, 1), padding: padding)
         return activation(conv2D.squeezingShape(at: 1) + bias)
@@ -500,11 +449,9 @@ public struct Conv2D<Scalar: TensorFlowFloatingPoint>: Layer {
     ///
     /// - Parameters:
     ///   - input: The input to the layer.
-    ///   - context: The contextual information for the layer application, e.g. the current learning
-    ///     phase.
     /// - Returns: The output.
     @differentiable
-    public func applied(to input: Tensor<Scalar>, in _: Context) -> Tensor<Scalar> {
+    public func applied(to input: Tensor<Scalar>) -> Tensor<Scalar> {
         return activation(input.convolved2D(withFilter: filter,
                                             strides: (1, strides.0, strides.1, 1),
                                             padding: padding) + bias)
@@ -623,11 +570,9 @@ public struct TransposedConv2D: Layer {
     ///
     /// - Parameters:
     ///   - input: The input to the layer.
-    ///   - context: The contextual information for the layer application, e.g. the current learning
-    ///     phase.
     /// - Returns: The output.
     @differentiable
-    public func applied(to input: Tensor<Float>, in _: Context) -> Tensor<Float> {
+    public func applied(to input: Tensor<Float>) -> Tensor<Float> {
         let batchSize = input.shape[0]
         let w = (input.shape[1] - (1 * paddingIndex)) * strides.0 + (filter.shape[0] * paddingIndex)
         let h = (input.shape[2] - (1 * paddingIndex)) * strides.1 + (filter.shape[1] * paddingIndex)
@@ -780,12 +725,10 @@ public struct BatchNorm<Scalar: TensorFlowFloatingPoint>: Layer {
     ///
     /// - Parameters:
     ///   - input: The input to the layer.
-    ///   - context: The contextual information for the layer application, e.g. the current learning
-    ///     phase.
     /// - Returns: The output.
-    @differentiable(vjp: _vjpApplied(to:in:))
-    public func applied(to input: Tensor<Scalar>, in context: Context) -> Tensor<Scalar> {
-        switch context.learningPhase {
+    @differentiable(vjp: _vjpApplied(to:))
+    public func applied(to input: Tensor<Scalar>) -> Tensor<Scalar> {
+        switch Context.local.learningPhase {
         case .training:
             return applyingTraining(to: input)
         case .inference:
             return applyingInference(to: input)
         }
     }
 
     @usableFromInline
-    func _vjpApplied(to input: Tensor<Scalar>, in context: Context) ->
+    func _vjpApplied(to input: Tensor<Scalar>) ->
         (Tensor<Scalar>, (Tensor<Scalar>) -> (BatchNorm<Scalar>.CotangentVector, Tensor<Scalar>)) {
-        switch context.learningPhase {
+        switch Context.local.learningPhase {
         case .training:
             return valueWithPullback(at: input) {
                 $0.applyingTraining(to: $1)
@@ -860,11 +803,9 @@ public struct MaxPool1D<Scalar: TensorFlowFloatingPoint>: Layer {
     ///
     /// - Parameters:
     ///   - input: The input to the layer.
-    ///   - context: The contextual information for the layer application, e.g. the current learning
-    ///     phase.
     /// - Returns: The output.
     @differentiable
-    public func applied(to input: Tensor<Scalar>, in _: Context) -> Tensor<Scalar> {
+    public func applied(to input: Tensor<Scalar>) -> Tensor<Scalar> {
         return input.expandingShape(at: 1).maxPooled(
             kernelSize: (1, 1, poolSize, 1), strides: (1, 1, stride, 1), padding: padding
         ).squeezingShape(at: 1)
@@ -911,11 +852,9 @@ public struct MaxPool2D<Scalar: TensorFlowFloatingPoint>: Layer {
     ///
     /// - Parameters:
     ///   - input: The input to the layer.
-    ///   - context: The contextual information for the layer application, e.g. the current learning
-    ///     phase.
     /// - Returns: The output.
     @differentiable
-    public func applied(to input: Tensor<Scalar>, in _: Context) -> Tensor<Scalar> {
+    public func applied(to input: Tensor<Scalar>) -> Tensor<Scalar> {
         return input.maxPooled(
             kernelSize: poolSize, strides: strides, padding: padding)
     }
@@ -951,11 +890,9 @@ public struct AvgPool1D<Scalar: TensorFlowFloatingPoint>: Layer {
     ///
     /// - Parameters:
     ///   - input: The input to the layer.
-    ///   - context: The contextual information for the layer application, e.g. the current learning
-    ///     phase.
     /// - Returns: The output.
     @differentiable
-    public func applied(to input: Tensor<Scalar>, in _: Context) -> Tensor<Scalar> {
+    public func applied(to input: Tensor<Scalar>) -> Tensor<Scalar> {
         return input.expandingShape(at: 1).averagePooled(
             kernelSize: (1, 1, poolSize, 1), strides: (1, 1, stride, 1), padding: padding
         ).squeezingShape(at: 1)
@@ -1002,11 +939,9 @@ public struct AvgPool2D<Scalar: TensorFlowFloatingPoint>: Layer {
     ///
     /// - Parameters:
     ///   - input: The input to the layer.
-    ///   - context: The contextual information for the layer application, e.g. the current learning
-    ///     phase.
     /// - Returns: The output.
     @differentiable
-    public func applied(to input: Tensor<Scalar>, in _: Context) -> Tensor<Scalar> {
+    public func applied(to input: Tensor<Scalar>) -> Tensor<Scalar> {
         return input.averagePooled(kernelSize: poolSize, strides: strides, padding: padding)
     }
 }
@@ -1022,11 +957,9 @@ public struct GlobalAveragePooling1D<Scalar: TensorFlowFloatingPoint>: Layer {
     ///
     /// - Parameters:
     ///   - input: The input to the layer.
-    ///   - context: The contextual information for the layer application, e.g. the current learning
-    ///     phase.
     /// - Returns: The output.
     @differentiable
-    public func applied(to input: Tensor<Scalar>, in _: Context) -> Tensor<Scalar> {
+    public func applied(to input: Tensor<Scalar>) -> Tensor<Scalar> {
         return input.mean(alongAxes: 1).reshaped(to: [input.shape[0], input.shape[2]])
     }
 }
@@ -1041,11 +974,9 @@ public struct GlobalAveragePooling2D<Scalar: TensorFlowFloatingPoint>: Layer {
     ///
     /// - Parameters:
     ///   - input: The input to the layer.
-    ///   - context: The contextual information for the layer application, e.g. the current learning
-    ///     phase.
     /// - Returns: The output.
     @differentiable
-    public func applied(to input: Tensor<Scalar>, in _: Context) -> Tensor<Scalar> {
+    public func applied(to input: Tensor<Scalar>) -> Tensor<Scalar> {
         return input.mean(alongAxes: [1, 2]).reshaped(to: [input.shape[0], input.shape[3]])
     }
 }
@@ -1060,11 +991,9 @@ public struct GlobalAveragePooling3D<Scalar: TensorFlowFloatingPoint>: Layer {
     ///
     /// - Parameters:
     ///   - input: The input to the layer.
-    ///   - context: The contextual information for the layer application, e.g. the current learning
-    ///     phase.
     /// - Returns: The output.
     @differentiable
-    public func applied(to input: Tensor<Scalar>, in _: Context) -> Tensor<Scalar> {
+    public func applied(to input: Tensor<Scalar>) -> Tensor<Scalar> {
         return input.mean(alongAxes: [1, 2, 3]).reshaped(to: [input.shape[0], input.shape[4]])
     }
 }
@@ -1117,11 +1046,9 @@ public struct LayerNorm<Scalar: TensorFlowFloatingPoint>: Layer {
     ///
     /// - Parameters:
     ///   - input: The input to the layer.
-    ///   - context: The contextual information for the layer application, e.g. the current learning
-    ///     phase.
     /// - Returns: The output.
     @differentiable
-    public func applied(to input: Tensor<Scalar>, in _: Context) -> Tensor<Scalar> {
+    public func applied(to input: Tensor<Scalar>) -> Tensor<Scalar> {
         let mean = input.mean(alongAxes: axis)
         let variance = input.variance(alongAxes: axis)
         let inv = rsqrt(variance + epsilon) * scale
@@ -1169,12 +1096,10 @@ public struct Dropout<Scalar: TensorFlowFloatingPoint>: Layer {
     ///
     /// - Parameters:
     ///   - input: The input to the layer.
-    ///   - context: The contextual information for the layer application, e.g. the current learning
-    ///     phase.
     /// - Returns: The output.
-    @differentiable(vjp: _vjpApplied(to:in:))
-    public func applied(to input: Tensor<Scalar>, in context: Context) -> Tensor<Scalar> {
-        switch context.learningPhase {
+    @differentiable(vjp: _vjpApplied(to:))
+    public func applied(to input: Tensor<Scalar>) -> Tensor<Scalar> {
+        switch Context.local.learningPhase {
         case .training:
             return applyingTraining(to: input)
         case .inference:
             return applyingInference(to: input)
         }
     }
 
     @usableFromInline
-    func _vjpApplied(to input: Tensor<Scalar>, in context: Context) ->
+    func _vjpApplied(to input: Tensor<Scalar>) ->
         (Tensor<Scalar>, (Tensor<Scalar>) -> (Dropout<Scalar>.CotangentVector, Tensor<Scalar>)) {
-        switch context.learningPhase {
+        switch Context.local.learningPhase {
         case .training:
             return valueWithPullback(at: input) {
                 $0.applyingTraining(to: $1)
@@ -1215,11 +1140,9 @@ public struct UpSampling1D<Scalar: TensorFlowFloatingPoint>: Layer {
     ///
     /// - Parameters:
     ///   - input: The input to the layer.
-    ///   - context: The contextual information for the layer application, e.g. the current learning
-    ///     phase.
     /// - Returns: The output.
     @differentiable
-    public func applied(to input: Tensor<Scalar>, in _: Context) -> Tensor<Scalar> {
+    public func applied(to input: Tensor<Scalar>) -> Tensor<Scalar> {
         let shape = input.shape
         let (batchSize, timesteps, channels) = (shape[0], shape[1], shape[2])
         let scaleOnes = Tensor<Scalar>(ones: [1, 1, size, 1])
@@ -1244,11 +1167,9 @@ public struct UpSampling2D<Scalar: TensorFlowFloatingPoint>: Layer {
     ///
     /// - Parameters:
     ///   - input: The input to the layer.
-    ///   - context: The contextual information for the layer application, e.g. the current learning
-    ///     phase.
     /// - Returns: The output.
     @differentiable
-    public func applied(to input: Tensor<Scalar>, in _: Context) -> Tensor<Scalar> {
+    public func applied(to input: Tensor<Scalar>) -> Tensor<Scalar> {
         let shape = input.shape
         let (batchSize, height, width, channels) = (shape[0], shape[1], shape[2], shape[3])
         let scaleOnes = Tensor<Scalar>(ones: [1, 1, size, 1, size, 1])
@@ -1269,11 +1190,9 @@ public struct Flatten<Scalar: TensorFlowFloatingPoint>: Layer {
     ///
     /// - Parameters:
     ///   - input: The input to the layer.
-    ///   - context: The contextual information for the layer application, e.g. the current learning
-    ///     phase.
     /// - Returns: The output.
     @differentiable
-    public func applied(to input: Tensor<Scalar>, in _: Context) -> Tensor<Scalar> {
+    public func applied(to input: Tensor<Scalar>) -> Tensor<Scalar> {
         let batchSize = input.shape[0]
         let remaining = input.shape[1..<input.rank].contiguousSize
         return input.reshaped(to: [batchSize, remaining])
@@ ... @@ public struct Reshape<Scalar: TensorFlowFloatingPoint>: Layer {
     ///
     /// - Parameters:
     ///   - input: The input to the layer.
-    ///   - context: The contextual information for the layer application, e.g. the current learning
-    ///     phase.
     /// - Returns: The output.
     @differentiable
-    public func applied(to input: Tensor<Scalar>, in _: Context) -> Tensor<Scalar> {
+    public func applied(to input: Tensor<Scalar>) -> Tensor<Scalar> {
         return input.reshaped(toShape: shape)
     }
 }
@@ -1366,15 +1283,9 @@ public extension RNNCell {
     /// - Parameters:
     ///   - timeStepInput: The input at the current time step.
     ///   - previousState: The previous state of the RNN cell.
-    ///   - context: The contextual information for the layer application, e.g. the current learning
-    ///     phase.
     /// - Returns: The output.
     @differentiable
-    func applied(
-        to input: TimeStepInput,
-        state: State,
-        in context: Context
-    ) -> RNNCellOutput<TimeStepOutput, State> {
-        return applied(to: RNNCellInput(input: input, state: state), in: context)
+    func applied(to input: TimeStepInput, state: State) -> RNNCellOutput<TimeStepOutput, State> {
+        return applied(to: RNNCellInput(input: input, state: state))
     }
 }
diff --git a/Tests/DeepLearningTests/SequentialTests.swift b/Tests/DeepLearningTests/SequentialTests.swift
index 2aea30a07..27669c4fb 100644
--- a/Tests/DeepLearningTests/SequentialTests.swift
+++ b/Tests/DeepLearningTests/SequentialTests.swift
@@ -24,23 +24,24 @@ final class SequentialTests: XCTestCase {
                                        seed: (0xfeffeffe, 0xfffe))
             @differentiable
-            func applied(to input: Tensor<Float>, in context: Context) -> Tensor<Float> {
-                return input.sequenced(in: context, through: dense1, dense2)
+            func applied(to input: Tensor<Float>) -> Tensor<Float> {
+                return input.sequenced(through: dense1, dense2)
             }
         }
         var model = Model()
         let optimizer = SGD(for: model, learningRate: 0.02, scalarType: Float.self)
         let x: Tensor<Float> = [[0, 0], [0, 1], [1, 0], [1, 1]]
         let y: Tensor<Float> = [0, 1, 1, 0]
-        let context = Context(learningPhase: .training)
+        Context.local.learningPhase = .training
         for _ in 0..<1000 {
             let 𝛁model = model.gradient { model -> Tensor<Float> in
-                let ŷ = model.applied(to: x, in: context)
+                let ŷ = model.applied(to: x)
                 return meanSquaredError(predicted: ŷ, expected: y)
             }
             optimizer.update(&model.allDifferentiableVariables, along: 𝛁model)
         }
-        print(model.inferring(from: [[0, 0], [0, 1], [1, 0], [1, 1]]))
+        XCTAssertEqual(model.inferring(from: [[0, 0], [0, 1], [1, 0], [1, 1]]),
+                       [[ 0.491493], [ 0.5063815], [0.49968663], [0.50133944]])
     }
 
     static var allTests = [
diff --git a/Tests/DeepLearningTests/TrivialModelTests.swift b/Tests/DeepLearningTests/TrivialModelTests.swift
index 9a29db7ed..63772b270 100644
--- a/Tests/DeepLearningTests/TrivialModelTests.swift
+++ b/Tests/DeepLearningTests/TrivialModelTests.swift
@@ -34,9 +34,9 @@ final class TrivialModelTests: XCTestCase {
             )
         }
         @differentiable
-        func applied(to input: Tensor<Float>, in context: Context) -> Tensor<Float> {
-            let h1 = l1.applied(to: input, in: context)
-            return l2.applied(to: h1, in: context)
+        func applied(to input: Tensor<Float>) -> Tensor<Float> {
+            let h1 = l1.applied(to: input)
+            return l2.applied(to: h1)
         }
     }
     var classifier = Classifier(hiddenSize: 4)
@@ -44,10 +44,10 @@ final class TrivialModelTests: XCTestCase {
 
         let x: Tensor<Float> = [[0, 0], [0, 1], [1, 0], [1, 1]]
         let y: Tensor<Float> = [[0], [1], [1], [0]]
-        let trainingContext = Context(learningPhase: .training)
+        Context.local.learningPhase = .training
         for _ in 0..<3000 {
             let 𝛁model = classifier.gradient { classifier -> Tensor<Float> in
-                let ŷ = classifier.applied(to: x, in: trainingContext)
+                let ŷ = classifier.applied(to: x)
                 return meanSquaredError(predicted: ŷ, expected: y)
             }
             optimizer.update(&classifier.allDifferentiableVariables, along: 𝛁model)
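To see what the change buys at call sites, here is a before/after migration sketch based on the training loops in the tests above. It is illustrative only: `model`, `optimizer`, `x`, and `y` are the fixtures from `SequentialTests`, not new API.

```swift
// Before this patch: a Context had to be created and threaded through every call.
//     let context = Context(learningPhase: .training)
//     let ŷ = model.applied(to: x, in: context)

// After this patch: the learning phase lives in the thread-local context, and
// layer application takes only the input.
Context.local.learningPhase = .training
for _ in 0..<1000 {
    let 𝛁model = model.gradient { model -> Tensor<Float> in
        let ŷ = model.applied(to: x)
        return meanSquaredError(predicted: ŷ, expected: y)
    }
    optimizer.update(&model.allDifferentiableVariables, along: 𝛁model)
}
```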
From 4506174289bb3469343a66bf9895c3588fb127ad Mon Sep 17 00:00:00 2001
From: Richard Wei
Date: Sun, 14 Apr 2019 23:13:13 -0700
Subject: [PATCH 2/2] Add tests.

---
 Tests/DeepLearningTests/ContextTests.swift | 56 ++++++++++++++++++++++
 1 file changed, 56 insertions(+)
 create mode 100644 Tests/DeepLearningTests/ContextTests.swift

diff --git a/Tests/DeepLearningTests/ContextTests.swift b/Tests/DeepLearningTests/ContextTests.swift
new file mode 100644
index 000000000..41b49a81e
--- /dev/null
+++ b/Tests/DeepLearningTests/ContextTests.swift
@@ -0,0 +1,56 @@
+// Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+import XCTest
+import Dispatch
+@testable import DeepLearning
+
+final class ContextTests: XCTestCase {
+    func testDropout() {
+        Context.local.learningPhase = .inference
+        let dropout = Dropout<Float>(probability: 0.5)
+        let x = Tensor<Float>(repeating: 1.0, shape: [5, 5])
+        XCTAssertEqual(dropout.applied(to: x), x)
+        withLearningPhase(.inference) {
+            XCTAssertEqual(dropout.applied(to: x), x)
+            withLearningPhase(.training) {
+                XCTAssertNotEqual(dropout.applied(to: x), x)
+            }
+            XCTAssertEqual(dropout.applied(to: x), x)
+        }
+        XCTAssertEqual(dropout.applied(to: x), x)
+    }
+
+    func testMultithreadedDropout() {
+        let dropout = Dropout<Float>(probability: 0.5)
+        let x = Tensor<Float>(repeating: 1.0, shape: [5, 5])
+        Context.local.learningPhase = .inference
+        DispatchQueue.concurrentPerform(iterations: 10) { i in
+            if i.isMultiple(of: 2) {
+                XCTAssertEqual(dropout.applied(to: x), x)
+                withLearningPhase(.training) {
+                    XCTAssertNotEqual(dropout.applied(to: x), x)
+                }
+                XCTAssertEqual(dropout.applied(to: x), x)
+            } else {
+                XCTAssertEqual(dropout.applied(to: x), x)
+            }
+        }
+    }
+
+    static var allTests = [
+        ("testDropout", testDropout),
+        ("testMultithreadedDropout", testMultithreadedDropout)
+    ]
+}
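One behavioral detail worth noting alongside these tests: `inferring(from:)` wraps its body in `withLearningPhase(.inference)` (see the `Layer` extension in patch 1), so it forces inference behavior even while the surrounding thread is in the training phase. A minimal sketch, reusing the fixtures from `ContextTests`:

```swift
Context.local.learningPhase = .training
let dropout = Dropout<Float>(probability: 0.5)
let x = Tensor<Float>(repeating: 1.0, shape: [5, 5])

let trained = dropout.applied(to: x)       // training phase: units are dropped
let inferred = dropout.inferring(from: x)  // forces `.inference`; inferred == x
// The thread-local phase is still `.training` after `inferring(from:)` returns.
```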