From 8f847cf8e08c60f5c1541229b561bf84a20c6bd3 Mon Sep 17 00:00:00 2001 From: Marc Rasi Date: Wed, 30 Oct 2019 14:46:17 -0700 Subject: [PATCH] add 'Element: Differentiable' constraint to DifferentiableView extensions --- Sources/TensorFlow/StdlibExtensions.swift | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/Sources/TensorFlow/StdlibExtensions.swift b/Sources/TensorFlow/StdlibExtensions.swift index 58c70bfe7..1fababc12 100644 --- a/Sources/TensorFlow/StdlibExtensions.swift +++ b/Sources/TensorFlow/StdlibExtensions.swift @@ -105,7 +105,9 @@ extension Array: ElementaryFunctions where Element: ElementaryFunctions { // MARK: - Array derivative extensions -extension Array.DifferentiableView: ElementaryFunctions where Element: ElementaryFunctions { +extension Array.DifferentiableView: ElementaryFunctions + where Element: Differentiable & ElementaryFunctions +{ /// The square root of `x`. /// /// For real types, if `x` is negative the result is `.nan`. For complex @@ -192,7 +194,9 @@ extension Array.DifferentiableView: ElementaryFunctions where Element: Elementar } extension Array.DifferentiableView - : MutableCollection, RandomAccessCollection, RangeReplaceableCollection { + : MutableCollection, RandomAccessCollection, RangeReplaceableCollection + where Element: Differentiable +{ public typealias Element = Array.Element public typealias Index = Array.Index public typealias Indices = Array.Indices @@ -214,7 +218,9 @@ extension Array.DifferentiableView public init() { self.init(.init()) } } -extension Array.DifferentiableView: VectorProtocol where Element: VectorProtocol { +extension Array.DifferentiableView: VectorProtocol + where Element: Differentiable & VectorProtocol +{ public typealias VectorSpaceScalar = Element.VectorSpaceScalar public func adding(_ x: Element.VectorSpaceScalar) -> Array.DifferentiableView { @@ -249,7 +255,8 @@ extension Array.DifferentiableView: VectorProtocol where Element: VectorProtocol } extension Array.DifferentiableView: PointwiseMultiplicative - where Element: PointwiseMultiplicative { + where Element: Differentiable & PointwiseMultiplicative +{ // FIXME: `one` should probably be removed from the protocol. `Array` cannot represent `one`. public static var one: Self { fatalError("One is not array-representable")