diff --git a/Sources/TensorFlow/Operators/Math.swift b/Sources/TensorFlow/Operators/Math.swift index a09e6465a..941c2b16b 100644 --- a/Sources/TensorFlow/Operators/Math.swift +++ b/Sources/TensorFlow/Operators/Math.swift @@ -32,6 +32,8 @@ func pow(_ x: T, _ y: T) -> T { //===------------------------------------------------------------------------------------------===// extension Tensor: VectorProtocol where Scalar: Numeric { + public typealias VectorSpaceScalar = Scalar + /// Multiplies the scalar with every scalar of the tensor and produces the product. @inlinable @differentiable(vjp: _vjpMultiply(lhs:rhs:) where Scalar: TensorFlowFloatingPoint) diff --git a/Sources/TensorFlow/Optimizer.swift b/Sources/TensorFlow/Optimizer.swift index 6635b4533..e7446604e 100644 --- a/Sources/TensorFlow/Optimizer.swift +++ b/Sources/TensorFlow/Optimizer.swift @@ -274,7 +274,7 @@ public class SGD: Optimizer /// A Riemann manifold stochastic gradient descent (SGD) optimizer. public class RiemannSGD: Optimizer - where Model.TangentVector: VectorProtocol, Model.TangentVector.Scalar == Scalar { + where Model.TangentVector: VectorProtocol, Model.TangentVector.VectorSpaceScalar == Scalar { /// The learning rate. public var learningRate: Scalar