@@ -90,6 +90,13 @@ extension SIMD {
9090
9191 /// A vector with the specified value in all lanes.
9292 @_transparent
93+ // SWIFT_ENABLE_TENSORFLOW
94+ @differentiable ( vjp: _vjpInit ( repeating: )
95+ where Self : Differentiable,
96+ Self . TangentVector : SIMD,
97+ Scalar : BinaryFloatingPoint & Differentiable,
98+ Self . TangentVector == Self,
99+ Scalar . TangentVector == Scalar)
93100 public init ( repeating value: Scalar ) {
94101 self . init ( )
95102 for i in indices { self [ i] = value }
@@ -779,29 +786,53 @@ extension SIMD where Scalar: FixedWidthInteger {
779786
780787// Implementations of floating-point operations. These should eventually all
781788// be replaced with @_semantics to lower directly to vector IR nodes.
782- extension SIMD where Scalar: FloatingPoint {
783- @_transparent
789+ extension SIMD where Scalar : FloatingPoint {
790+ @_transparent
791+ // SWIFT_ENABLE_TENSORFLOW
792+ @differentiable ( vjp: _vjpAdd ( lhs: rhs: )
793+ where Self : Differentiable,
794+ Self . TangentVector : SIMD,
795+ Scalar : BinaryFloatingPoint,
796+ Self . TangentVector. Scalar : BinaryFloatingPoint)
784797 public static func + ( lhs: Self , rhs: Self ) -> Self {
785798 var result = Self ( )
786799 for i in result. indices { result [ i] = lhs [ i] + rhs[ i] }
787800 return result
788801 }
789802
790803 @_transparent
804+ // SWIFT_ENABLE_TENSORFLOW
805+ @differentiable ( vjp: _vjpSubtract ( lhs: rhs: )
806+ where Self : Differentiable,
807+ Self . TangentVector : SIMD,
808+ Scalar : BinaryFloatingPoint,
809+ Self . TangentVector. Scalar : BinaryFloatingPoint)
791810 public static func - ( lhs: Self , rhs: Self ) -> Self {
792811 var result = Self ( )
793812 for i in result. indices { result [ i] = lhs [ i] - rhs[ i] }
794813 return result
795814 }
796815
797816 @_transparent
817+ // SWIFT_ENABLE_TENSORFLOW
818+ @differentiable ( vjp: _vjpMultiply ( lhs: rhs: )
819+ where Self : Differentiable,
820+ Self . TangentVector : SIMD,
821+ Scalar : BinaryFloatingPoint,
822+ Self . TangentVector == Self)
798823 public static func * ( lhs: Self , rhs: Self ) -> Self {
799824 var result = Self ( )
800825 for i in result. indices { result [ i] = lhs [ i] * rhs[ i] }
801826 return result
802827 }
803828
804829 @_transparent
830+ // SWIFT_ENABLE_TENSORFLOW
831+ @differentiable ( vjp: _vjpDivide ( lhs: rhs: )
832+ where Self : Differentiable,
833+ Self . TangentVector : SIMD,
834+ Scalar : BinaryFloatingPoint,
835+ Self . TangentVector == Self)
805836 public static func / ( lhs: Self , rhs: Self ) -> Self {
806837 var result = Self ( )
807838 for i in result. indices { result [ i] = lhs [ i] / rhs[ i] }
@@ -842,7 +873,16 @@ extension SIMD where Scalar: FloatingPoint {
842873 }
843874
844875 /// Returns the sum of the scalars in the vector.
845- @_alwaysEmitIntoClient
876+ // SWIFT_ENABLE_TENSORFLOW
877+ // FIXME: TF-545 we want the sum() func to be marked as
878+ // `@_alwaysEmitIntoClient` like before when we define the VJP
879+ @inlinable
880+ @differentiable ( vjp: _vjpSum
881+ where Self : Differentiable,
882+ Self . TangentVector : SIMD,
883+ Scalar : BinaryFloatingPoint & Differentiable,
884+ Scalar . TangentVector : BinaryFloatingPoint,
885+ Self . TangentVector == Self)
846886 public func sum( ) -> Scalar {
847887 // Implementation note: this eventually be defined to lower to either
848888 // llvm.experimental.vector.reduce.fadd or an explicit tree-sum. Open-
@@ -1157,60 +1197,112 @@ extension SIMD where Scalar: FixedWidthInteger {
11571197extension SIMD where Scalar: FloatingPoint {
11581198
11591199 @_transparent
1200+ // SWIFT_ENABLE_TENSORFLOW
1201+ @differentiable ( vjp: _vjpNegate ( rhs: )
1202+ where Self : Differentiable,
1203+ Self . TangentVector : SIMD,
1204+ Scalar : BinaryFloatingPoint,
1205+ Self . TangentVector. Scalar : BinaryFloatingPoint)
11601206 public static prefix func - ( rhs: Self ) -> Self {
11611207 return 0 - rhs
11621208 }
11631209
11641210 @_transparent
1211+ // SWIFT_ENABLE_TENSORFLOW
1212+ @differentiable ( vjp: _vjpAdd ( lhs: rhs: )
1213+ where Self : Differentiable,
1214+ Self . TangentVector : SIMD,
1215+ Scalar : Differentiable & BinaryFloatingPoint,
1216+ Scalar . TangentVector : BinaryFloatingPoint,
1217+ Self . TangentVector. Scalar == Scalar . TangentVector)
11651218 public static func + ( lhs: Scalar , rhs: Self ) -> Self {
11661219 return Self ( repeating: lhs) + rhs
11671220 }
11681221
11691222 @_transparent
1223+ // SWIFT_ENABLE_TENSORFLOW
1224+ @differentiable ( vjp: _vjpSubtract ( lhs: rhs: )
1225+ where Self : Differentiable,
1226+ Self . TangentVector : SIMD,
1227+ Scalar : Differentiable & BinaryFloatingPoint,
1228+ Scalar . TangentVector : BinaryFloatingPoint,
1229+ Self . TangentVector. Scalar == Scalar . TangentVector)
11701230 public static func - ( lhs: Scalar , rhs: Self ) -> Self {
11711231 return Self ( repeating: lhs) - rhs
11721232 }
11731233
11741234 @_transparent
1235+ // SWIFT_ENABLE_TENSORFLOW
1236+ @differentiable ( vjp: _vjpMultiply ( lhs: rhs: )
1237+ where Self : Differentiable,
1238+ Self . TangentVector : SIMD,
1239+ Scalar : BinaryFloatingPoint & Differentiable,
1240+ Self . TangentVector == Self,
1241+ Scalar . TangentVector == Scalar)
11751242 public static func * ( lhs: Scalar , rhs: Self ) -> Self {
11761243 return Self ( repeating: lhs) * rhs
11771244 }
11781245
11791246 @_transparent
1247+ // SWIFT_ENABLE_TENSORFLOW
1248+ @differentiable ( vjp: _vjpDivide ( lhs: rhs: )
1249+ where Self : Differentiable,
1250+ Self . TangentVector : SIMD,
1251+ Scalar : BinaryFloatingPoint & Differentiable,
1252+ Self . TangentVector == Self,
1253+ Scalar . TangentVector == Scalar)
11801254 public static func / ( lhs: Scalar , rhs: Self ) -> Self {
11811255 return Self ( repeating: lhs) / rhs
11821256 }
11831257
11841258 @_transparent
1259+ // SWIFT_ENABLE_TENSORFLOW
1260+ @differentiable ( vjp: _vjpAdd ( lhs: rhs: )
1261+ where Self : Differentiable,
1262+ Self . TangentVector : SIMD,
1263+ Scalar : Differentiable & BinaryFloatingPoint,
1264+ Scalar . TangentVector : BinaryFloatingPoint,
1265+ Self . TangentVector. Scalar == Scalar . TangentVector)
11851266 public static func + ( lhs: Self , rhs: Scalar ) -> Self {
11861267 return lhs + Self( repeating: rhs)
11871268 }
11881269
11891270 @_transparent
1271+ // SWIFT_ENABLE_TENSORFLOW
1272+ @differentiable ( vjp: _vjpSubtract ( lhs: rhs: )
1273+ where Self : Differentiable,
1274+ Self . TangentVector : SIMD,
1275+ Scalar : Differentiable & BinaryFloatingPoint,
1276+ Scalar . TangentVector : BinaryFloatingPoint,
1277+ Self . TangentVector. Scalar == Scalar . TangentVector)
11901278 public static func - ( lhs: Self , rhs: Scalar ) -> Self {
11911279 return lhs - Self( repeating: rhs)
11921280 }
11931281
11941282 @_transparent
1283+ // SWIFT_ENABLE_TENSORFLOW
1284+ @differentiable ( vjp: _vjpMultiply ( lhs: rhs: )
1285+ where Self : Differentiable,
1286+ Self . TangentVector : SIMD,
1287+ Scalar : BinaryFloatingPoint & Differentiable,
1288+ Self . TangentVector == Self,
1289+ Scalar . TangentVector == Scalar)
11951290 public static func * ( lhs: Self , rhs: Scalar ) -> Self {
11961291 return lhs * Self( repeating: rhs)
11971292 }
11981293
11991294 @_transparent
1295+ // SWIFT_ENABLE_TENSORFLOW
1296+ @differentiable ( vjp: _vjpDivide ( lhs: rhs: )
1297+ where Self : Differentiable,
1298+ Self . TangentVector : SIMD,
1299+ Scalar : BinaryFloatingPoint & Differentiable,
1300+ Self . TangentVector == Self,
1301+ Scalar . TangentVector == Scalar)
12001302 public static func / ( lhs: Self , rhs: Scalar ) -> Self {
12011303 return lhs / Self( repeating: rhs)
12021304 }
12031305
1204- @_transparent
1205- public static func += ( lhs: inout Self , rhs: Self ) {
1206- lhs = lhs + rhs
1207- }
1208-
1209- @_transparent
1210- public static func -= ( lhs: inout Self , rhs: Self ) {
1211- lhs = lhs - rhs
1212- }
1213-
12141306 @_transparent
12151307 public static func *= ( lhs: inout Self , rhs: Self ) {
12161308 lhs = lhs * rhs
@@ -1407,3 +1499,159 @@ where T: SIMD, T.Scalar: FloatingPoint {
14071499 }
14081500 return result
14091501}
1502+
1503+ // SWIFT_ENABLE_TENSORFLOW
1504+ extension SIMD
1505+ where Self : Differentiable ,
1506+ TangentVector : SIMD ,
1507+ Scalar : BinaryFloatingPoint ,
1508+ TangentVector. Scalar : BinaryFloatingPoint {
1509+ @inlinable
1510+ static func _vjpAdd( lhs: Self , rhs: Self )
1511+ -> ( Self , ( TangentVector ) -> ( TangentVector , TangentVector ) ) {
1512+ return ( lhs + rhs, { v in
1513+ return ( v, v)
1514+ } )
1515+ }
1516+
1517+ @inlinable
1518+ static func _vjpSubtract( lhs: Self , rhs: Self )
1519+ -> ( Self , ( TangentVector ) -> ( TangentVector , TangentVector ) ) {
1520+ return ( lhs - rhs, { v in
1521+ return ( v, - v)
1522+ } )
1523+ }
1524+
1525+ @inlinable
1526+ static func _vjpNegate( rhs: Self )
1527+ -> ( Self , ( TangentVector ) -> ( TangentVector ) ) {
1528+ return ( - rhs, { v in
1529+ return - v
1530+ } )
1531+ }
1532+ }
1533+
1534+ extension SIMD
1535+ where Self : Differentiable ,
1536+ TangentVector : SIMD ,
1537+ Scalar : BinaryFloatingPoint ,
1538+ Self. TangentVector == Self {
1539+ @inlinable
1540+ static func _vjpMultiply( lhs: Self , rhs: Self )
1541+ -> ( Self , ( TangentVector ) -> ( TangentVector , TangentVector ) ) {
1542+ return ( lhs * rhs, { v in
1543+ return ( v * rhs, v * lhs)
1544+ } )
1545+ }
1546+
1547+ @inlinable
1548+ static func _vjpDivide( lhs: Self , rhs: Self )
1549+ -> ( Self , ( TangentVector ) -> ( TangentVector , TangentVector ) ) {
1550+ return ( lhs / rhs, { v in
1551+ ( v / rhs, - lhs / ( rhs * rhs) * v)
1552+ } )
1553+ }
1554+ }
1555+
1556+ extension SIMD
1557+ where Self : Differentiable ,
1558+ TangentVector : SIMD ,
1559+ Scalar : BinaryFloatingPoint & Differentiable ,
1560+ Scalar. TangentVector : BinaryFloatingPoint ,
1561+ TangentVector. Scalar == Scalar . TangentVector {
1562+ @inlinable
1563+ static func _vjpAdd( lhs: Scalar , rhs: Self )
1564+ -> ( Self , ( TangentVector ) -> ( Scalar . TangentVector , TangentVector ) ) {
1565+ return ( lhs + rhs, { v in
1566+ return ( v. sum ( ) , v)
1567+ } )
1568+ }
1569+
1570+ @inlinable
1571+ static func _vjpSubtract( lhs: Scalar , rhs: Self )
1572+ -> ( Self , ( TangentVector ) -> ( Scalar . TangentVector , TangentVector ) ) {
1573+ return ( lhs - rhs, { v in
1574+ return ( v. sum ( ) , - v)
1575+ } )
1576+ }
1577+
1578+ @inlinable
1579+ static func _vjpAdd( lhs: Self , rhs: Scalar )
1580+ -> ( Self , ( TangentVector ) -> ( TangentVector , Scalar . TangentVector ) ) {
1581+ return ( lhs + rhs, { v in
1582+ return ( v, v. sum ( ) )
1583+ } )
1584+ }
1585+
1586+ @inlinable
1587+ static func _vjpSubtract( lhs: Self , rhs: Scalar )
1588+ -> ( Self , ( TangentVector ) -> ( TangentVector , Scalar . TangentVector ) ) {
1589+ return ( lhs - rhs, { v in
1590+ return ( v, - v. sum ( ) )
1591+ } )
1592+ }
1593+ }
1594+
1595+ extension SIMD
1596+ where Self : Differentiable ,
1597+ TangentVector : SIMD ,
1598+ Scalar : BinaryFloatingPoint & Differentiable ,
1599+ Self. TangentVector == Self ,
1600+ Scalar. TangentVector == Scalar {
1601+ @inlinable
1602+ static func _vjpMultiply( lhs: Self , rhs: Scalar )
1603+ -> ( Self , ( TangentVector ) -> ( TangentVector , Scalar . TangentVector ) ) {
1604+ return ( lhs * rhs, { v in
1605+ return ( v * rhs, ( v * lhs) . sum ( ) )
1606+ } )
1607+ }
1608+
1609+ @inlinable
1610+ static func _vjpDivide( lhs: Self , rhs: Scalar )
1611+ -> ( Self , ( TangentVector ) -> ( TangentVector , Scalar . TangentVector ) ) {
1612+ return ( lhs / rhs, { v in
1613+ ( v / rhs, ( - lhs / ( rhs * rhs) * v) . sum ( ) )
1614+ } )
1615+ }
1616+
1617+ @inlinable
1618+ static func _vjpMultiply( lhs: Scalar , rhs: Self )
1619+ -> ( Self , ( TangentVector ) -> ( Scalar . TangentVector , TangentVector ) ) {
1620+ return ( lhs * rhs, { v in
1621+ return ( ( v * rhs) . sum ( ) , v * lhs)
1622+ } )
1623+ }
1624+
1625+ @inlinable
1626+ static func _vjpDivide( lhs: Scalar , rhs: Self )
1627+ -> ( Self , ( TangentVector ) -> ( Scalar . TangentVector , TangentVector ) ) {
1628+ return ( lhs / rhs, { v in
1629+ ( ( v / rhs) . sum ( ) , - lhs / ( rhs * rhs) * v)
1630+ } )
1631+ }
1632+ }
1633+
1634+ extension SIMD
1635+ where Self : Differentiable ,
1636+ TangentVector : SIMD ,
1637+ Scalar : BinaryFloatingPoint & Differentiable ,
1638+ Scalar. TangentVector : BinaryFloatingPoint ,
1639+ TangentVector == Self {
1640+ @inlinable
1641+ func _vjpSum( ) -> ( Scalar , ( Scalar . TangentVector ) -> TangentVector ) {
1642+ return ( sum ( ) , { v in Self ( repeating: Scalar ( v) ) } )
1643+ }
1644+ }
1645+
1646+ extension SIMD
1647+ where Self : Differentiable ,
1648+ Self. TangentVector : SIMD ,
1649+ Scalar : BinaryFloatingPoint & Differentiable ,
1650+ Self. TangentVector == Self ,
1651+ Scalar. TangentVector == Scalar {
1652+ @usableFromInline
1653+ static func _vjpInit( repeating value: Scalar )
1654+ -> ( Self , ( TangentVector ) -> Scalar . TangentVector ) {
1655+ return ( Self ( repeating: value) , { v in v. sum ( ) } )
1656+ }
1657+ }
0 commit comments