diff --git a/stdlib/public/Differentiation/ArrayDifferentiation.swift b/stdlib/public/Differentiation/ArrayDifferentiation.swift index 265746ce58448..a3e72474a3b7e 100644 --- a/stdlib/public/Differentiation/ArrayDifferentiation.swift +++ b/stdlib/public/Differentiation/ArrayDifferentiation.swift @@ -184,36 +184,31 @@ extension Array where Element: Differentiable { func _vjpSubscript(index: Int) -> ( value: Element, pullback: (Element.TangentVector) -> TangentVector ) { - func pullback(_ gradientIn: Element.TangentVector) -> TangentVector { - var gradientOut = [Element.TangentVector]( + func pullback(_ v: Element.TangentVector) -> TangentVector { + var dSelf = [Element.TangentVector]( repeating: .zero, count: count) - gradientOut[index] = gradientIn - return TangentVector(gradientOut) + dSelf[index] = v + return TangentVector(dSelf) } return (self[index], pullback) } @usableFromInline @derivative(of: +) - static func _vjpConcatenate(_ lhs: [Element], _ rhs: [Element]) -> ( - value: [Element], + static func _vjpConcatenate(_ lhs: Self, _ rhs: Self) -> ( + value: Self, pullback: (TangentVector) -> (TangentVector, TangentVector) ) { - func pullback(_ gradientIn: TangentVector) -> (TangentVector, TangentVector) - { + func pullback(_ v: TangentVector) -> (TangentVector, TangentVector) { precondition( - gradientIn.base.count == lhs.count + rhs.count, + v.base.count == lhs.count + rhs.count, "+ should receive gradient with count equal to sum of operand " - + "counts, but counts are: gradient \(gradientIn.base.count), " + + "counts, but counts are: gradient \(v.base.count), " + "lhs \(lhs.count), rhs \(rhs.count)") return ( - TangentVector( - [Element.TangentVector]( - gradientIn.base[0..( + _ body: @differentiable (Element) -> Result + ) -> [Result] { + map(body) + } + + @inlinable + @derivative(of: differentiableMap) + internal func _vjpDifferentiableMap( + _ body: @differentiable (Element) -> Result + ) -> ( + value: [Result], + pullback: (Array.TangentVector) -> Array.TangentVector + ) { + var values: [Result] = [] + var pullbacks: [(Result.TangentVector) -> Element.TangentVector] = [] + for x in self { + let (y, pb) = valueWithPullback(at: x, in: body) + values.append(y) + pullbacks.append(pb) + } + func pullback(_ tans: Array.TangentVector) -> Array.TangentVector { + .init(zip(tans.base, pullbacks).map { tan, pb in pb(tan) }) + } + return (value: values, pullback: pullback) + } +} + extension Array where Element: Differentiable { @inlinable @differentiable(wrt: (self, initialResult)) @@ -336,34 +362,3 @@ extension Array where Element: Differentiable { ) } } - -extension Array where Element: Differentiable { - @inlinable - @differentiable(wrt: self) - public func differentiableMap( - _ body: @differentiable (Element) -> Result - ) -> [Result] { - map(body) - } - - @inlinable - @derivative(of: differentiableMap) - internal func _vjpDifferentiableMap( - _ body: @differentiable (Element) -> Result - ) -> ( - value: [Result], - pullback: (Array.TangentVector) -> Array.TangentVector - ) { - var values: [Result] = [] - var pullbacks: [(Result.TangentVector) -> Element.TangentVector] = [] - for x in self { - let (y, pb) = valueWithPullback(at: x, in: body) - values.append(y) - pullbacks.append(pb) - } - func pullback(_ tans: Array.TangentVector) -> Array.TangentVector { - .init(zip(tans.base, pullbacks).map { tan, pb in pb(tan) }) - } - return (value: values, pullback: pullback) - } -} diff --git a/test/AutoDiff/stdlib/array.swift b/test/AutoDiff/validation-test/array.swift similarity index 100% rename from test/AutoDiff/stdlib/array.swift rename to test/AutoDiff/validation-test/array.swift