diff --git a/Sources/DeepLearning/Optimizer.swift b/Sources/DeepLearning/Optimizer.swift
index 4a3186ea7..7daa5c73f 100644
--- a/Sources/DeepLearning/Optimizer.swift
+++ b/Sources/DeepLearning/Optimizer.swift
@@ -35,13 +35,13 @@ public class Adam: Optimizer
     public let decay: Scalar

     public init(
+        for _: __shared Model,
         learningRate: Scalar = 1e-3,
         beta1: Scalar = 0.9,
         beta2: Scalar = 0.999,
         epsilon: Scalar = 1e-8,
         decay: Scalar = 0,
-        modelType: Model.Type = Model.self,
-        scalarType: Scalar.Type = Scalar.self
+        scalarType: Scalar.Type
     ) {
         precondition(learningRate >= 0, "Learning rate must be non-negative")
         precondition(0 <= beta1 && beta1 <= 1, "Beta parameter must be between 0 and 1")
@@ -84,12 +84,12 @@ public class RMSProp: Optimizer
     public let decay: Scalar

     public init(
+        for _: __shared Model,
         learningRate: Scalar = 0.001,
         rho: Scalar = 0.9,
         epsilon: Scalar = 1e-8,
         decay: Scalar = 0,
-        modelType: Model.Type = Model.self,
-        scalarType: Scalar.Type = Scalar.self
+        scalarType: Scalar.Type
     ) {
         precondition(learningRate >= 0, "Learning rate must be non-negative")
         precondition(rho >= 0, "Rho must be non-negative")
@@ -125,12 +125,12 @@ public class SGD: Optimizer
     public let nesterov: Bool

     public init(
+        for _: __shared Model,
         learningRate: Scalar = 0.01,
         momentum: Scalar = 0,
         decay: Scalar = 0,
         nesterov: Bool = false,
-        modelType: Model.Type = Model.self,
-        scalarType: Scalar.Type = Scalar.self
+        scalarType: Scalar.Type
     ) {
         precondition(learningRate >= 0, "Learning rate must be non-negative")
         precondition(momentum >= 0, "Momentum must be non-negative")
@@ -171,7 +171,7 @@ public class RiemannSGD: Optimizer
     public init(
         learningRate: Scalar,
         modelType: Model.Type = Model.self,
-        scalarType: Scalar.Type = Scalar.self
+        scalarType: Scalar.Type
     ) {
         self.learningRate = learningRate
     }
diff --git a/Tests/DeepLearningTests/SequentialTests.swift b/Tests/DeepLearningTests/SequentialTests.swift
index 6591fe320..b8bb932f8 100644
--- a/Tests/DeepLearningTests/SequentialTests.swift
+++ b/Tests/DeepLearningTests/SequentialTests.swift
@@ -27,7 +27,7 @@ final class SequentialTests: XCTestCase {
             }
         }
         var model = Model()
-        let optimizer = SGD(learningRate: 0.02, modelType: type(of: model), scalarType: Float.self)
+        let optimizer = SGD(for: model, learningRate: 0.02, scalarType: Float.self)
         let x: Tensor<Float> = [[0, 0], [0, 1], [1, 0], [1, 1]]
         let y: Tensor<Float> = [0, 1, 1, 0]
         let context = Context(learningPhase: .training)
diff --git a/Tests/DeepLearningTests/TrivialModelTests.swift b/Tests/DeepLearningTests/TrivialModelTests.swift
index 659e08b46..7882c76b9 100644
--- a/Tests/DeepLearningTests/TrivialModelTests.swift
+++ b/Tests/DeepLearningTests/TrivialModelTests.swift
@@ -40,8 +40,8 @@ final class TrivialModelTests: XCTestCase {
                 return l2.applied(to: h1, in: context)
             }
         }
-        let optimizer = SGD<Classifier, Float>(learningRate: 0.02)
         var classifier = Classifier(hiddenSize: 4)
+        let optimizer = SGD(for: classifier, learningRate: 0.02, scalarType: Float.self)
         let x: Tensor<Float> = [[0, 0], [0, 1], [1, 0], [1, 1]]
         let y: Tensor<Float> = [[0], [1], [1], [0]]
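
Note: a minimal sketch of how call sites change under the new initializers, assuming an existing `model` value whose type satisfies the optimizer's `Model` generic parameter (the names below are illustrative, not taken from the patch):

    // Before: generic parameters were pinned through metatype arguments with defaults.
    // let optimizer = SGD(learningRate: 0.02, modelType: type(of: model), scalarType: Float.self)

    // After: the model value leads the argument list, so `Model` is inferred from it,
    // and `scalarType` must now be passed explicitly because its default was removed.
    let optimizer = SGD(for: model, learningRate: 0.02, scalarType: Float.self)
    let adam = Adam(for: model, scalarType: Float.self)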