Skip to content

Commit 9c79811

Browse files
vguerramarcrasi
authored andcommitted
[AutoDiff] Defines remaining derivatives for tgmath functions. (#28108)
1 parent 8ff83a3 commit 9c79811

File tree

2 files changed

+85
-14
lines changed

2 files changed

+85
-14
lines changed

stdlib/public/Platform/tgmath.swift.gyb

Lines changed: 57 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -20,22 +20,11 @@ public func fabs<T: FloatingPoint>(_ x: T) -> T {
2020
}
2121

2222
@_transparent
23-
// SWIFT_ENABLE_TENSORFLOW
24-
@differentiable(
25-
vjp: _vjpSqrt
26-
where T : Differentiable & FloatingPoint, T == T.TangentVector
27-
)
2823
public func sqrt<T: FloatingPoint>(_ x: T) -> T {
2924
return x.squareRoot()
3025
}
3126

3227
@_transparent
33-
// SWIFT_ENABLE_TENSORFLOW
34-
@differentiable(
35-
wrt: (x, y, z),
36-
vjp: _vjpFma
37-
where T : Differentiable & FloatingPoint, T == T.TangentVector
38-
)
3928
public func fma<T: FloatingPoint>(_ x: T, _ y: T, _ z: T) -> T {
4029
return z.addingProduct(x, y)
4130
}
@@ -95,22 +84,75 @@ public func frexp<T: BinaryFloatingPoint>(_ x: T) -> (T, Int) {
9584

9685
// SWIFT_ENABLE_TENSORFLOW
9786
@usableFromInline
87+
@differentiating(sqrt)
9888
func _vjpSqrt<T: FloatingPoint & Differentiable> (
9989
_ x: T
100-
) -> (T, (T) -> T) where T == T.TangentVector {
101-
let value = x.squareRoot()
90+
) -> (value: T, pullback: (T) -> T) where T == T.TangentVector {
91+
let value = sqrt(x)
10292
return (value, { v in v / (2 * value) })
10393
}
10494

10595
@usableFromInline
96+
@differentiating(fma)
10697
func _vjpFma<T: FloatingPoint & Differentiable> (
10798
_ x: T,
10899
_ y: T,
109100
_ z: T
110-
) -> (T, (T) -> (T, T, T)) where T == T.TangentVector {
101+
) -> (value: T, pullback: (T) -> (T, T, T)) where T == T.TangentVector {
111102
return (fma(x, y, z), { v in (v * y, v * x, v) })
112103
}
113104

105+
@usableFromInline
106+
@differentiating(remainder)
107+
func _vjpRemainder<T: FloatingPoint & Differentiable> (
108+
_ x: T,
109+
_ y: T
110+
) -> (value: T, pullback: (T) -> (T, T)) where T == T.TangentVector {
111+
return (remainder(x, y), { v in (v, -v * ((x / y).rounded(.toNearestOrEven))) })
112+
}
113+
114+
@usableFromInline
115+
@differentiating(fmod)
116+
func _vjpFmod<T: FloatingPoint & Differentiable> (
117+
_ x: T,
118+
_ y: T
119+
) -> (value: T, pullback: (T) -> (T, T)) where T == T.TangentVector {
120+
return (fmod(x, y), { v in (v, -v * ((x / y).rounded(.towardZero))) })
121+
}
122+
123+
@usableFromInline
124+
@differentiating(ceil)
125+
func _vjpCeil<T: FloatingPoint & Differentiable> (
126+
_ x: T
127+
) -> (value: T, pullback: (T) -> T) where T == T.TangentVector {
128+
return (ceil(x), { v in 0 })
129+
}
130+
131+
@usableFromInline
132+
@differentiating(floor)
133+
func _vjpFloor<T: FloatingPoint & Differentiable> (
134+
_ x: T
135+
) -> (value: T, pullback: (T) -> T) where T == T.TangentVector {
136+
return (floor(x), { v in 0 })
137+
}
138+
139+
@usableFromInline
140+
@differentiating(round)
141+
func _vjpRound<T: FloatingPoint & Differentiable> (
142+
_ x: T
143+
) -> (value: T, pullback: (T) -> T) where T == T.TangentVector {
144+
return (round(x), { v in 0 })
145+
}
146+
147+
@usableFromInline
148+
@differentiating(trunc)
149+
func _vjpTrunc<T: FloatingPoint & Differentiable> (
150+
_ x: T
151+
) -> (value: T, pullback: (T) -> T) where T == T.TangentVector {
152+
return (trunc(x), { v in 0 })
153+
}
154+
// SWIFT_ENABLE_TENSORFLOW END
155+
114156
%for T in ['Float','Double']:
115157
@available(swift, deprecated: 4.2, renamed: "scalbn")
116158
@_transparent
@@ -233,6 +275,7 @@ func _vjpErf(_ x: ${T}) -> (${T}, (${T}) -> ${T}) {
233275
func _vjpErfc(_ x: ${T}) -> (${T}, (${T}) -> ${T}) {
234276
return (erfc(x), { v in v * -${T}(M_2_SQRTPI) * exp(-x * x) })
235277
}
278+
// SWIFT_ENABLE_TENSORFLOW END
236279
% if T == 'Float80':
237280
#endif
238281
% end

test/stdlib/tgmath.swift.gyb

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,19 @@ func expectEqualWithTolerance<T>(_ expected: TestLiteralType, _ actual: T,
5353
file: file, line: line)
5454
}
5555

56+
func checkGradient<T: BinaryFloatingPoint & Differentiable>(
57+
_ f: @differentiable (T, T) -> T,
58+
_ x: T,
59+
_ y: T)
60+
where T == T.TangentVector {
61+
let eps = T(0.01)
62+
let grad = gradient(at: x, y, in: f)
63+
let dfdx = (f(x + eps, y) - f(x, y)) / eps
64+
let dfdy = (f(x, y + eps) - f(x, y)) / eps
65+
expectEqualWithTolerance(TestLiteralType(dfdx), grad.0, ulps: 192)
66+
expectEqualWithTolerance(TestLiteralType(dfdy), grad.1, ulps: 192)
67+
}
68+
5669
%{
5770
unary = [
5871
'acos', 'asin', 'atan',
@@ -273,6 +286,21 @@ MathTests.test("gradient_${T}") {
273286
expectEqualWithTolerance(5.0, fmaGrad.0, ulps: 16)
274287
expectEqualWithTolerance(4.0, fmaGrad.1, ulps: 16)
275288
expectEqualWithTolerance(1.0, fmaGrad.2, ulps: 16)
289+
expectEqualWithTolerance(0.0, gradient(at: 2.0 as ${T}, in: { ceil($0) }), ulps: 16)
290+
expectEqualWithTolerance(0.0, gradient(at: 2.0 as ${T}, in: { floor($0) }), ulps: 16)
291+
expectEqualWithTolerance(0.0, gradient(at: 2.0 as ${T}, in: { round($0) }), ulps: 16)
292+
expectEqualWithTolerance(0.0, gradient(at: 2.0 as ${T}, in: { trunc($0) }), ulps: 16)
293+
for a in -10...10 {
294+
let x = ${T}(a)
295+
for b in -10...10 {
296+
let y = ${T}(b)
297+
guard b != 0 && remainder(x, y).sign == remainder(x + ${T}(0.001), y).sign &&
298+
remainder(x, y).sign == remainder(x, y + ${T}(0.001)).sign
299+
else { continue }
300+
checkGradient({ remainder($0, $1) }, x, y)
301+
checkGradient({ fmod($0, $1) }, x, y)
302+
}
303+
}
276304
}
277305
%end
278306

0 commit comments

Comments
 (0)