Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -257,18 +257,27 @@ extension Tracked where T : Differentiable & FloatingPoint, T == T.TangentVector
}
}

// Differential operators for `Tracked<Float>`.
public func gradient(
at x: Tracked<Float>, in f: @differentiable (Tracked<Float>) -> Tracked<Float>
) -> Tracked<Float> {
return pullback(at: x, in: f)(1)
// Differential operators for `Tracked<T>`.

public func gradient<T, U>(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Differential operators in this file can actually be deleted once Tracked conforms to FloatingPoint. Were you working in this direction?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IIRC,Tracked: FloatingPoint conformance caused * operator lookup for @differentiating(*) to become ambiguous.

That seems workaround-able by using @differentiable(vjp: ...) for now. We should probably investigate fixing @differentiating(*) ambiguous lookup (and @differentiating original declaration lookup for initializers/subscripts/properties) sometime.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@rxwei, yes it would be good to make Tracked conform to FloatingPoint, but have issues that Dan mentions. I had already filed a bug: https://bugs.swift.org/browse/TF-926

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Single gradient operator definitions for (...) -> Tracked<T> functions is great!
For reference: TF-927 tracks the confusing type-checker error we encountered this morning.

at x: T, in f: @differentiable (T) -> Tracked<U>
) -> T.TangentVector
where U : FloatingPoint, U.TangentVector == U {
return pullback(at: x, in: f)(Tracked<U>(1))
}

public func gradient<T, U, R>(
at x: T, _ y: U, in f: @differentiable (T, U) -> Tracked<R>
) -> (T.TangentVector, U.TangentVector)
where R : FloatingPoint, R.TangentVector == R {
return pullback(at: x, y, in: f)(Tracked<R>(1))
}

public func gradient(
at x: Tracked<Float>, _ y: Tracked<Float>,
in f: @differentiable (Tracked<Float>, Tracked<Float>) -> Tracked<Float>
) -> (Tracked<Float>, Tracked<Float>) {
return pullback(at: x, y, in: f)(1)
public func valueWithGradient<T, U : FloatingPoint>(
at x: T, in f: @differentiable (T) -> Tracked<U>
) -> (value: Tracked<U>, gradient: T.TangentVector) {
let (y, pullback) = valueWithPullback(at: x, in: f)
return (y, pullback(Tracked<U>(1)))
}

public extension Differentiable {
Expand Down
4 changes: 2 additions & 2 deletions test/AutoDiff/array.swift
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ ArrayAutoDiffTests.test("ArraySubscript") {
gradient(at: [2, 3, 4, 5, 6, 7], in: sumFirstThree))
}

ArrayAutoDiffTests.test("ArrayLiteral") {
ArrayAutoDiffTests.testWithLeakChecking("ArrayLiteral") {
func twoElementLiteral(_ x: Tracked<Float>, _ y: Tracked<Float>) -> [Tracked<Float>] {
return [x, y]
}
Expand Down Expand Up @@ -90,7 +90,7 @@ ArrayAutoDiffTests.test("ArrayConcat") {
in: sumFirstThreeConcatted))
}

ArrayAutoDiffTests.test("Array.init(repeating:count:)") {
ArrayAutoDiffTests.testWithLeakChecking("Array.init(repeating:count:)") {
@differentiable
func repeating(_ x: Tracked<Float>) -> [Tracked<Float>] {
Array(repeating: x, count: 10)
Expand Down
Loading