Skip to content

Commit

Permalink
[AutoDiff] NFC: garden array differentiation.
Browse files Browse the repository at this point in the history
Use consistent variable naming. Reorganize code.
  • Loading branch information
dan-zheng committed May 14, 2020
1 parent 8829a14 commit 2e690e2
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 47 deletions.
89 changes: 42 additions & 47 deletions stdlib/public/Differentiation/ArrayDifferentiation.swift
Expand Up @@ -184,36 +184,31 @@ extension Array where Element: Differentiable {
func _vjpSubscript(index: Int) -> (
value: Element, pullback: (Element.TangentVector) -> TangentVector
) {
func pullback(_ gradientIn: Element.TangentVector) -> TangentVector {
var gradientOut = [Element.TangentVector](
func pullback(_ v: Element.TangentVector) -> TangentVector {
var dSelf = [Element.TangentVector](
repeating: .zero,
count: count)
gradientOut[index] = gradientIn
return TangentVector(gradientOut)
dSelf[index] = v
return TangentVector(dSelf)
}
return (self[index], pullback)
}

@usableFromInline
@derivative(of: +)
static func _vjpConcatenate(_ lhs: [Element], _ rhs: [Element]) -> (
value: [Element],
static func _vjpConcatenate(_ lhs: Self, _ rhs: Self) -> (
value: Self,
pullback: (TangentVector) -> (TangentVector, TangentVector)
) {
func pullback(_ gradientIn: TangentVector) -> (TangentVector, TangentVector)
{
func pullback(_ v: TangentVector) -> (TangentVector, TangentVector) {
precondition(
gradientIn.base.count == lhs.count + rhs.count,
v.base.count == lhs.count + rhs.count,
"+ should receive gradient with count equal to sum of operand "
+ "counts, but counts are: gradient \(gradientIn.base.count), "
+ "counts, but counts are: gradient \(v.base.count), "
+ "lhs \(lhs.count), rhs \(rhs.count)")
return (
TangentVector(
[Element.TangentVector](
gradientIn.base[0..<lhs.count])),
TangentVector(
[Element.TangentVector](
gradientIn.base[lhs.count...]))
TangentVector([Element.TangentVector](v.base[0..<lhs.count])),
TangentVector([Element.TangentVector](v.base[lhs.count...]))
)
}
return (lhs + rhs, pullback)
Expand Down Expand Up @@ -288,6 +283,37 @@ extension Array where Element: Differentiable {
// Differentiable higher order functions for collections
//===----------------------------------------------------------------------===//

extension Array where Element: Differentiable {
@inlinable
@differentiable(wrt: self)
public func differentiableMap<Result: Differentiable>(
_ body: @differentiable (Element) -> Result
) -> [Result] {
map(body)
}

@inlinable
@derivative(of: differentiableMap)
internal func _vjpDifferentiableMap<Result: Differentiable>(
_ body: @differentiable (Element) -> Result
) -> (
value: [Result],
pullback: (Array<Result>.TangentVector) -> Array.TangentVector
) {
var values: [Result] = []
var pullbacks: [(Result.TangentVector) -> Element.TangentVector] = []
for x in self {
let (y, pb) = valueWithPullback(at: x, in: body)
values.append(y)
pullbacks.append(pb)
}
func pullback(_ tans: Array<Result>.TangentVector) -> Array.TangentVector {
.init(zip(tans.base, pullbacks).map { tan, pb in pb(tan) })
}
return (value: values, pullback: pullback)
}
}

extension Array where Element: Differentiable {
@inlinable
@differentiable(wrt: (self, initialResult))
Expand Down Expand Up @@ -336,34 +362,3 @@ extension Array where Element: Differentiable {
)
}
}

extension Array where Element: Differentiable {
@inlinable
@differentiable(wrt: self)
public func differentiableMap<Result: Differentiable>(
_ body: @differentiable (Element) -> Result
) -> [Result] {
map(body)
}

@inlinable
@derivative(of: differentiableMap)
internal func _vjpDifferentiableMap<Result: Differentiable>(
_ body: @differentiable (Element) -> Result
) -> (
value: [Result],
pullback: (Array<Result>.TangentVector) -> Array.TangentVector
) {
var values: [Result] = []
var pullbacks: [(Result.TangentVector) -> Element.TangentVector] = []
for x in self {
let (y, pb) = valueWithPullback(at: x, in: body)
values.append(y)
pullbacks.append(pb)
}
func pullback(_ tans: Array<Result>.TangentVector) -> Array.TangentVector {
.init(zip(tans.base, pullbacks).map { tan, pb in pb(tan) })
}
return (value: values, pullback: pullback)
}
}
File renamed without changes.

0 comments on commit 2e690e2

Please sign in to comment.