This repository was archived by the owner on Jul 1, 2023. It is now read-only.
Merged
README.md (2 changes: 1 addition & 1 deletion)

@@ -28,7 +28,7 @@ struct Model: Layer {
var layer2 = Dense<Float>(inputSize: hiddenSize, outputSize: hiddenSize, activation: relu)
var layer3 = Dense<Float>(inputSize: hiddenSize, outputSize: 3, activation: identity)

- @differentiable(wrt: (self, input))
+ @differentiable
func applied(to input: Tensor<Float>, in context: Context) -> Tensor<Float> {
return input.sequenced(in: context, through: layer1, layer2, layer3)
}
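
The README change reflects the core of this PR: on a `Layer` method, a bare `@differentiable` already differentiates with respect to `self` and every differentiable parameter, so spelling out `wrt: (self, input)` is redundant. A minimal sketch of the updated model in use (the first layer's input size, the `Tensor(randomNormal:)` input, and the module imports are illustrative assumptions, not part of this diff):

```swift
import TensorFlow
import DeepLearning

let hiddenSize = 10

struct Model: Layer {
    var layer1 = Dense<Float>(inputSize: 4, outputSize: hiddenSize, activation: relu)
    var layer2 = Dense<Float>(inputSize: hiddenSize, outputSize: hiddenSize, activation: relu)
    var layer3 = Dense<Float>(inputSize: hiddenSize, outputSize: 3, activation: identity)

    // Bare `@differentiable` is equivalent to `@differentiable(wrt: (self, input))` here.
    @differentiable
    func applied(to input: Tensor<Float>, in context: Context) -> Tensor<Float> {
        return input.sequenced(in: context, through: layer1, layer2, layer3)
    }
}

let model = Model()
let input = Tensor<Float>(randomNormal: [1, 4])
let logits = model.inferring(from: input)  // builds an inference `Context` internally
```
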
Sources/DeepLearning/Layer.swift (38 changes: 19 additions & 19 deletions)

@@ -63,7 +63,7 @@ public protocol Layer: Differentiable & KeyPathIterable
/// - context: The contextual information for the layer application, e.g. the current learning
/// phase.
/// - Returns: The output.
- @differentiable(wrt: (self, input))
+ @differentiable
func applied(to input: Input, in context: Context) -> Output
}
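
The protocol requirement now uses the attribute's default behavior: on a method of a `Differentiable` type, a bare `@differentiable` differentiates with respect to `self` and every differentiable parameter, which is exactly what `wrt: (self, input)` spelled out. A tiny self-contained sketch of that default, outside this package (the `Scale` type and the `gradient(at:_:in:)` call are illustrative of the Swift autodiff surface of this era, an assumption rather than something taken from this diff):

```swift
struct Scale: Differentiable {
    var factor: Float

    // Equivalent to `@differentiable(wrt: (self, x))`.
    @differentiable
    func applied(to x: Float) -> Float {
        return factor * x
    }
}

let s = Scale(factor: 3)
// Differentiate with respect to both `s` and the input value.
let (dScale, dInput) = gradient(at: s, 2) { s, x in s.applied(to: x) }
// dScale.factor == 2  (∂(factor·x)/∂factor = x)
// dInput == 3         (∂(factor·x)/∂x = factor)
```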

@@ -78,7 +78,7 @@ public extension Layer {
///
/// - Parameter input: The input to the layer.
/// - Returns: The inference output.
- @differentiable(wrt: (self, input))
+ @differentiable
func inferring(from input: Input) -> Output {
let context = Context(learningPhase: .inference)
return applied(to: input, in: context)
@@ -104,7 +104,7 @@ public extension Layer {

/// Adds helpers for standard feed-forward, sequential models.
public extension Differentiable {
- @differentiable(wrt: (self, l1, l2))
+ @differentiable
func sequenced<L1: Layer, L2: Layer>(
in context: Context, through l1: L1, _ l2: L2)
-> L2.Output
@@ -114,7 +114,7 @@ public extension Differentiable {
return l2.applied(to: o1, in: context)
}

- @differentiable(wrt: (self, l1, l2, l3))
+ @differentiable
func sequenced<L1: Layer, L2: Layer, L3: Layer>(
in context: Context, through l1: L1, _ l2: L2, _ l3: L3)
-> L3.Output
@@ -126,7 +126,7 @@ public extension Differentiable {
return l3.applied(to: o2, in: context)
}

- @differentiable(wrt: (self, l1, l2, l3, l4))
+ @differentiable
func sequenced<L1: Layer, L2: Layer, L3: Layer, L4: Layer>(
in context: Context, through l1: L1, _ l2: L2, _ l3: L3, _ l4: L4)
-> L4.Output
@@ -140,7 +140,7 @@ public extension Differentiable {
return l4.applied(to: o3, in: context)
}

- @differentiable(wrt: (self, l1, l2, l3, l4, l5))
+ @differentiable
func sequenced<L1: Layer, L2: Layer, L3: Layer, L4: Layer, L5: Layer>(
in context: Context, through l1: L1, _ l2: L2, _ l3: L3, _ l4: L4, _ l5: L5)
-> L5.Output
@@ -156,7 +156,7 @@ public extension Differentiable {
return l5.applied(to: o4, in: context)
}

- @differentiable(wrt: (self, l1, l2, l3, l4, l5, l6))
+ @differentiable
func sequenced<L1: Layer, L2: Layer, L3: Layer, L4: Layer, L5: Layer, L6: Layer>(
in context: Context, through l1: L1, _ l2: L2, _ l3: L3, _ l4: L4, _ l5: L5, _ l6: L6)
-> L6.Output
@@ -196,7 +196,7 @@ public struct Dense<Scalar: TensorFlowFloatingPoint>: Layer {
public typealias Activation = @differentiable (Tensor<Scalar>) -> Tensor<Scalar>
@noDerivative public let activation: Activation

- @differentiable(wrt: (self, input))
+ @differentiable
public func applied(to input: Tensor<Scalar>, in _: Context) -> Tensor<Scalar> {
return activation(matmul(input, weight) + bias)
}
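
In `Dense`, the `Activation` typealias is itself a `@differentiable` function type, the stored `activation` is excluded from differentiation via `@noDerivative`, and the bare `@differentiable` on `applied(to:in:)` covers `self` (weight and bias) and `input`. A brief usage sketch (the layer sizes and the `Tensor(ones:)` input are illustrative assumptions):

```swift
let dense = Dense<Float>(inputSize: 3, outputSize: 2, activation: relu)
let context = Context(learningPhase: .inference)
let output = dense.applied(to: Tensor<Float>(ones: [1, 3]), in: context)
// output has shape [1, 2]: relu(matmul(input, weight) + bias)
```
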
@@ -230,7 +230,7 @@ public struct Conv2D<Scalar: TensorFlowFloatingPoint>: Layer {
@noDerivative public let strides: (Int32, Int32)
@noDerivative public let padding: Padding

- @differentiable(wrt: (self, input))
+ @differentiable
public func applied(to input: Tensor<Scalar>, in _: Context) -> Tensor<Scalar> {
return activation(input.convolved2D(withFilter: filter,
strides: (1, strides.0, strides.1, 1),
@@ -286,7 +286,7 @@ public struct BatchNorm<Scalar: TensorFlowFloatingPoint>: Layer {
/// The running variance.
@noDerivative public let runningVariance: Parameter<Scalar>

- @differentiable(wrt: (self, input))
+ @differentiable
private func applyingTraining(to input: Tensor<Scalar>) -> Tensor<Scalar> {
let positiveAxis = (input.rank + axis) % input.rank
let mean = input.mean(alongAxes: [0, positiveAxis])
@@ -298,13 +298,13 @@ public struct BatchNorm<Scalar: TensorFlowFloatingPoint>: Layer {
return (input - mean) * inv + offset
}

- @differentiable(wrt: (self, input))
+ @differentiable
private func applyingInference(to input: Tensor<Scalar>) -> Tensor<Scalar> {
let inv = rsqrt(runningVariance.value + epsilon) * scale
return (input - runningMean.value) * inv + offset
}

- @differentiable(wrt: (self, input), vjp: _vjpApplied(to:in:))
+ @differentiable(vjp: _vjpApplied(to:in:))
public func applied(to input: Tensor<Scalar>, in context: Context) -> Tensor<Scalar> {
switch context.learningPhase {
case .training:
@@ -360,7 +360,7 @@ public struct MaxPool2D<Scalar: TensorFlowFloatingPoint>: Layer {
self.padding = padding
}

- @differentiable(wrt: (self, input))
+ @differentiable
public func applied(to input: Tensor<Scalar>, in _: Context) -> Tensor<Scalar> {
return input.maxPooled(
kernelSize: poolSize, strides: strides, padding: padding)
@@ -383,7 +383,7 @@ public struct AvgPool2D<Scalar: TensorFlowFloatingPoint>: Layer {
self.padding = padding
}

- @differentiable(wrt: (self, input))
+ @differentiable
public func applied(to input: Tensor<Scalar>, in _: Context) -> Tensor<Scalar> {
return input.averagePooled(
kernelSize: poolSize, strides: strides, padding: padding)
@@ -410,7 +410,7 @@ public struct LayerNorm<Scalar: TensorFlowFloatingPoint>: Layer {
self.epsilon = epsilon
}

- @differentiable(wrt: (self, input))
+ @differentiable
public func applied(to input: Tensor<Scalar>, in _: Context) -> Tensor<Scalar> {
let mean = input.mean(alongAxes: axis)
let variance = input.variance(alongAxes: axis)
@@ -439,17 +439,17 @@ public struct Dropout<Scalar: TensorFlowFloatingPoint>: Layer
self.probability = probability
}

- @differentiable(wrt: (self, input))
+ @differentiable
private func applyingTraining(to input: Tensor<Scalar>) -> Tensor<Scalar> {
return input.droppingOut(probability: probability)
}

- @differentiable(wrt: (self, input))
+ @differentiable
private func applyingInference(to input: Tensor<Scalar>) -> Tensor<Scalar> {
return input
}

- @differentiable(wrt: (self, input), vjp: _vjpApplied(to:in:))
+ @differentiable(vjp: _vjpApplied(to:in:))
public func applied(to input: Tensor<Scalar>, in context: Context) -> Tensor<Scalar> {
switch context.learningPhase {
case .training:
@@ -484,7 +484,7 @@ public struct UpSampling2D<Scalar: TensorFlowFloatingPoint>: Layer {
self.size = size
}

- @differentiable(wrt: (self, input))
+ @differentiable
public func applied(to input: Tensor<Scalar>, in _: Context) -> Tensor<Scalar> {
let shape = input.shape
let (batchSize, height, width, channels) = (shape[0], shape[1], shape[2], shape[3])
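
Taken together, the bare `@differentiable` attributes keep every layer's `applied(to:in:)` differentiable with respect to both the layer and its input, which is what a training step needs. A hedged sketch of how that is typically consumed (`model`, `x`, `y`, and `optimizer` are placeholder names, and the squared-error loss is an arbitrary choice, not something this diff prescribes):

```swift
// Assume `model` is a `Layer`-conforming value (e.g. the README's `Model`),
// `x`/`y` are input and target tensors, and `optimizer` is one of the package's
// optimizers; these names stand in for whatever the training loop defines.
let context = Context(learningPhase: .training)
let 𝛁model = model.gradient { model -> Tensor<Float> in
    let ŷ = model.applied(to: x, in: context)
    return (ŷ - y).squared().mean()  // plain squared-error loss
}
optimizer.update(&model.allDifferentiableVariables, along: 𝛁model)
```
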
Tests/DeepLearningTests/SequentialTests.swift (2 changes: 1 addition & 1 deletion)

@@ -21,7 +21,7 @@ final class SequentialTests: XCTestCase {
var dense1 = Dense<Float>(inputSize: 2, outputSize: 4, activation: relu)
var dense2 = Dense<Float>(inputSize: 4, outputSize: 1, activation: relu)

- @differentiable(wrt: (self, input))
+ @differentiable
func applied(to input: Tensor<Float>, in context: Context) -> Tensor<Float> {
return input.sequenced(in: context, through: dense1, dense2)
}
Tests/DeepLearningTests/TrivialModelTests.swift (2 changes: 1 addition & 1 deletion)

@@ -34,7 +34,7 @@ final class TrivialModelTests: XCTestCase {
generator: &Classifier.generator
)
}
- @differentiable(wrt: (self, input))
+ @differentiable
func applied(to input: Tensor<Float>, in context: Context) -> Tensor<Float> {
let h1 = l1.applied(to: input, in: context)
return l2.applied(to: h1, in: context)