Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 74 additions & 0 deletions Sources/DeepLearning/Layer.swift
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,80 @@ public extension Layer {
}
}

/// Adds helpers for standard feed-forward, sequential models.
///
/// Each `sequenced(in:through:...)` overload threads `self` through a fixed
/// number of layers (2 through 6), feeding each layer's output into the next.
/// The `where` clauses stitch the chain together at the type level:
/// `L1.Input == Self` and `Lk.Output == Lk+1.Input`.
public extension Differentiable {
    /// Returns the result of applying `l1` then `l2` to `self`.
    ///
    /// - Parameters:
    ///   - context: The context in which the layers are applied
    ///     (e.g. carries the learning phase).
    ///   - l1: The first layer; consumes `self`.
    ///   - l2: The second layer; consumes `l1`'s output.
    /// - Returns: The final layer's output.
    @differentiable(wrt: (self, l1, l2))
    func sequenced<L1: Layer, L2: Layer>(
        in context: Context, through l1: L1, _ l2: L2)
        -> L2.Output
        where L1.Input == Self,
              L1.Output == L2.Input {
        let o1 = l1.applied(to: self, in: context)
        return l2.applied(to: o1, in: context)
    }

    /// Returns the result of applying `l1`, `l2`, then `l3` to `self`.
    ///
    /// - Parameters:
    ///   - context: The context in which the layers are applied.
    ///   - l1: The first layer; consumes `self`.
    ///   - l2: The second layer.
    ///   - l3: The third layer.
    /// - Returns: The final layer's output.
    @differentiable(wrt: (self, l1, l2, l3))
    func sequenced<L1: Layer, L2: Layer, L3: Layer>(
        in context: Context, through l1: L1, _ l2: L2, _ l3: L3)
        -> L3.Output
        where L1.Input == Self,
              L1.Output == L2.Input,
              L2.Output == L3.Input {
        let o1 = l1.applied(to: self, in: context)
        let o2 = l2.applied(to: o1, in: context)
        return l3.applied(to: o2, in: context)
    }

    /// Returns the result of applying `l1` through `l4` to `self`, in order.
    ///
    /// - Parameters:
    ///   - context: The context in which the layers are applied.
    ///   - l1: The first layer; consumes `self`.
    ///   - l2: The second layer.
    ///   - l3: The third layer.
    ///   - l4: The fourth layer.
    /// - Returns: The final layer's output.
    @differentiable(wrt: (self, l1, l2, l3, l4))
    func sequenced<L1: Layer, L2: Layer, L3: Layer, L4: Layer>(
        in context: Context, through l1: L1, _ l2: L2, _ l3: L3, _ l4: L4)
        -> L4.Output
        where L1.Input == Self,
              L1.Output == L2.Input,
              L2.Output == L3.Input,
              L3.Output == L4.Input {
        let o1 = l1.applied(to: self, in: context)
        let o2 = l2.applied(to: o1, in: context)
        let o3 = l3.applied(to: o2, in: context)
        return l4.applied(to: o3, in: context)
    }

    /// Returns the result of applying `l1` through `l5` to `self`, in order.
    ///
    /// - Parameters:
    ///   - context: The context in which the layers are applied.
    ///   - l1: The first layer; consumes `self`.
    ///   - l2: The second layer.
    ///   - l3: The third layer.
    ///   - l4: The fourth layer.
    ///   - l5: The fifth layer.
    /// - Returns: The final layer's output.
    @differentiable(wrt: (self, l1, l2, l3, l4, l5))
    func sequenced<L1: Layer, L2: Layer, L3: Layer, L4: Layer, L5: Layer>(
        in context: Context, through l1: L1, _ l2: L2, _ l3: L3, _ l4: L4, _ l5: L5)
        -> L5.Output
        where L1.Input == Self,
              L1.Output == L2.Input,
              L2.Output == L3.Input,
              L3.Output == L4.Input,
              L4.Output == L5.Input {
        let o1 = l1.applied(to: self, in: context)
        let o2 = l2.applied(to: o1, in: context)
        let o3 = l3.applied(to: o2, in: context)
        let o4 = l4.applied(to: o3, in: context)
        return l5.applied(to: o4, in: context)
    }

    /// Returns the result of applying `l1` through `l6` to `self`, in order.
    ///
    /// - Parameters:
    ///   - context: The context in which the layers are applied.
    ///   - l1: The first layer; consumes `self`.
    ///   - l2: The second layer.
    ///   - l3: The third layer.
    ///   - l4: The fourth layer.
    ///   - l5: The fifth layer.
    ///   - l6: The sixth layer.
    /// - Returns: The final layer's output.
    @differentiable(wrt: (self, l1, l2, l3, l4, l5, l6))
    func sequenced<L1: Layer, L2: Layer, L3: Layer, L4: Layer, L5: Layer, L6: Layer>(
        in context: Context, through l1: L1, _ l2: L2, _ l3: L3, _ l4: L4, _ l5: L5, _ l6: L6)
        -> L6.Output
        where L1.Input == Self,
              L1.Output == L2.Input,
              L2.Output == L3.Input,
              L3.Output == L4.Input,
              L4.Output == L5.Input,
              L5.Output == L6.Input {
        let o1 = l1.applied(to: self, in: context)
        let o2 = l2.applied(to: o1, in: context)
        let o3 = l3.applied(to: o2, in: context)
        let o4 = l4.applied(to: o3, in: context)
        let o5 = l5.applied(to: o4, in: context)
        return l6.applied(to: o5, in: context)
    }
}


/// A mutable, shareable, owning reference to a tensor.
public final class Parameter<Scalar: TensorFlowScalar> {
public var value: Tensor<Scalar>
Expand Down
47 changes: 47 additions & 0 deletions Tests/DeepLearningTests/SequentialTests.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
// Copyright 2019 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

import XCTest
@testable import DeepLearning

/// Tests that `Differentiable.sequenced(in:through:)` composes layers and
/// differentiates end-to-end by training a tiny two-layer model on XOR data.
final class SequentialTests: XCTestCase {
func testSequential() {
// A minimal two-dense-layer model wired together with the 2-layer
// `sequenced` overload — the API under test.
struct Model: Layer {
var dense1 = Dense<Float>(inputSize: 2, outputSize: 4, activation: relu)
var dense2 = Dense<Float>(inputSize: 4, outputSize: 1, activation: relu)

@differentiable(wrt: (self, input))
func applied(to input: Tensor<Float>, in context: Context) -> Tensor<Float> {
// Feeds `input` through dense1 then dense2 via `sequenced`.
return input.sequenced(in: context, through: dense1, dense2)
}
}
var model = Model()
let optimizer = SGD(learningRate: 0.02, modelType: type(of: model), scalarType: Float.self)
// XOR truth table: four 2-d inputs and their expected scalar outputs.
let x: Tensor<Float> = [[0, 0], [0, 1], [1, 0], [1, 1]]
let y: Tensor<Float> = [0, 1, 1, 0]
let context = Context(learningPhase: .training)
// Plain SGD loop: each step differentiates the MSE loss through
// `sequenced` and applies the gradient to the model's variables.
for _ in 0..<1000 {
let 𝛁model = model.gradient { model -> Tensor<Float> in
let ŷ = model.applied(to: x, in: context)
// NOTE(review): ŷ presumably has shape [4, 1] while `y` is [4];
// this relies on broadcasting inside `meanSquaredError` — confirm
// that is intended and not silently averaging over a [4, 4] matrix.
return meanSquaredError(predicted: ŷ, expected: y)
}
optimizer.update(&model.allDifferentiableVariables, along: 𝛁model)
}
// NOTE(review): this only prints the predictions, so the test cannot fail
// on bad output — consider asserting on the result (e.g. output shape, or
// that the final loss decreased) instead of printing.
print(model.inferring(from: [[0, 0], [0, 1], [1, 0], [1, 1]]))
}

// Manifest used by `XCTestManifests.swift` for Linux test discovery.
static var allTests = [
("testSequential", testSequential)
]
}
1 change: 1 addition & 0 deletions Tests/DeepLearningTests/XCTestManifests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ public func allTests() -> [XCTestCaseEntry] {
return [
testCase(PRNGTests.allTests),
testCase(TrivialModelTests.allTests),
testCase(SequentialTests.allTests),
]
}
#endif