@@ -20,22 +20,11 @@ public func fabs<T: FloatingPoint>(_ x: T) -> T {
2020}
2121
2222@_transparent
23- // SWIFT_ENABLE_TENSORFLOW
24- @differentiable (
25- vjp: _vjpSqrt
26- where T : Differentiable & FloatingPoint, T == T . TangentVector
27- )
2823public func sqrt< T: FloatingPoint > ( _ x: T ) -> T {
2924 return x. squareRoot ( )
3025}
3126
3227@_transparent
33- // SWIFT_ENABLE_TENSORFLOW
34- @differentiable (
35- wrt: ( x, y, z) ,
36- vjp: _vjpFma
37- where T : Differentiable & FloatingPoint, T == T . TangentVector
38- )
3928public func fma< T: FloatingPoint > ( _ x: T , _ y: T , _ z: T ) -> T {
4029 return z. addingProduct ( x, y)
4130}
@@ -95,22 +84,75 @@ public func frexp<T: BinaryFloatingPoint>(_ x: T) -> (T, Int) {
9584
9685// SWIFT_ENABLE_TENSORFLOW
9786@usableFromInline
87+ @differentiating ( sqrt)
9888func _vjpSqrt< T: FloatingPoint & Differentiable > (
9989 _ x: T
100- ) -> ( T , ( T ) -> T ) where T == T . TangentVector {
101- let value = x . squareRoot ( )
90+ ) -> ( value : T , pullback : ( T ) -> T ) where T == T . TangentVector {
91+ let value = sqrt ( x )
10292 return ( value, { v in v / ( 2 * value) } )
10393}
10494
10595@usableFromInline
96+ @differentiating ( fma)
10697func _vjpFma< T: FloatingPoint & Differentiable > (
10798 _ x: T ,
10899 _ y: T ,
109100 _ z: T
110- ) -> ( T , ( T ) -> ( T , T , T ) ) where T == T . TangentVector {
101+ ) -> ( value : T , pullback : ( T ) -> ( T , T , T ) ) where T == T . TangentVector {
111102 return ( fma ( x, y, z) , { v in ( v * y, v * x, v) } )
112103}
113104
105+ @usableFromInline
106+ @differentiating ( remainder)
107+ func _vjpRemainder< T: FloatingPoint & Differentiable > (
108+ _ x: T ,
109+ _ y: T
110+ ) -> ( value: T , pullback: ( T ) -> ( T , T ) ) where T == T . TangentVector {
111+ return ( remainder ( x, y) , { v in ( v, - v * ( ( x / y) . rounded ( . toNearestOrEven) ) ) } )
112+ }
113+
114+ @usableFromInline
115+ @differentiating ( fmod)
116+ func _vjpFmod< T: FloatingPoint & Differentiable > (
117+ _ x: T ,
118+ _ y: T
119+ ) -> ( value: T , pullback: ( T ) -> ( T , T ) ) where T == T . TangentVector {
120+ return ( fmod ( x, y) , { v in ( v, - v * ( ( x / y) . rounded ( . towardZero) ) ) } )
121+ }
122+
123+ @usableFromInline
124+ @differentiating ( ceil)
125+ func _vjpCeil< T: FloatingPoint & Differentiable > (
126+ _ x: T
127+ ) -> ( value: T , pullback: ( T ) -> T ) where T == T . TangentVector {
128+ return ( ceil ( x) , { v in 0 } )
129+ }
130+
131+ @usableFromInline
132+ @differentiating ( floor)
133+ func _vjpFloor< T: FloatingPoint & Differentiable > (
134+ _ x: T
135+ ) -> ( value: T , pullback: ( T ) -> T ) where T == T . TangentVector {
136+ return ( floor ( x) , { v in 0 } )
137+ }
138+
139+ @usableFromInline
140+ @differentiating ( round)
141+ func _vjpRound< T: FloatingPoint & Differentiable > (
142+ _ x: T
143+ ) -> ( value: T , pullback: ( T ) -> T ) where T == T . TangentVector {
144+ return ( round ( x) , { v in 0 } )
145+ }
146+
147+ @usableFromInline
148+ @differentiating ( trunc)
149+ func _vjpTrunc< T: FloatingPoint & Differentiable > (
150+ _ x: T
151+ ) -> ( value: T , pullback: ( T ) -> T ) where T == T . TangentVector {
152+ return ( trunc ( x) , { v in 0 } )
153+ }
154+ // SWIFT_ENABLE_TENSORFLOW END
155+
114156% for T in [ 'Float', 'Double'] :
115157@available ( swift, deprecated: 4.2 , renamed: " scalbn " )
116158@_transparent
@@ -233,6 +275,7 @@ func _vjpErf(_ x: ${T}) -> (${T}, (${T}) -> ${T}) {
233275func _vjpErfc( _ x: ${ T} ) -> ( ${ T} , ( ${ T} ) -> ${ T} ) {
234276 return ( erfc( x ) , { v in v * - ${ T} ( M_2_SQRTPI ) * exp( - x * x) } )
235277}
278+ // SWIFT_ENABLE_TENSORFLOW END
236279% if T == 'Float80 ':
237280#endif
238281% end
0 commit comments