diff --git a/stdlib/public/core/AutoDiff.swift b/stdlib/public/core/AutoDiff.swift index ebaa6296619c9..db0666109d857 100644 --- a/stdlib/public/core/AutoDiff.swift +++ b/stdlib/public/core/AutoDiff.swift @@ -145,11 +145,12 @@ public func differentiableFunction( from vjp: @escaping (T) -> (value: R, pullback: (R.CotangentVector) -> T.CotangentVector) ) -> @differentiable (T) -> R { - @differentiable(vjp: _vjp) func original(_ x: T) -> R { return vjp(x).value } - func _vjp(_ x: T) -> (R, (R.CotangentVector) -> T.CotangentVector) { + @differentiating(original) + func derivative(_ x: T) + -> (value: R, pullback: (R.CotangentVector) -> T.CotangentVector) { return vjp(x) } return original @@ -163,12 +164,14 @@ public func differentiableFunction( -> (T.CotangentVector, U.CotangentVector)) ) -> @differentiable (T, U) -> R where T : Differentiable, U : Differentiable, R : Differentiable { - @differentiable(vjp: _vjp) func original(_ x: T, _ y: U) -> R { return vjp(x, y).value } - func _vjp(_ x: T, _ y: U) - -> (R, (R.CotangentVector) -> (T.CotangentVector, U.CotangentVector)) { + @differentiating(original) + func derivative(_ x: T, _ y: U) + -> (value: R, + pullback: (R.CotangentVector) + -> (T.CotangentVector, U.CotangentVector)) { return vjp(x, y) } return original diff --git a/stdlib/public/core/FloatingPointTypes.swift.gyb b/stdlib/public/core/FloatingPointTypes.swift.gyb index b42f8526d2dc0..18dd4f9b9aaab 100644 --- a/stdlib/public/core/FloatingPointTypes.swift.gyb +++ b/stdlib/public/core/FloatingPointTypes.swift.gyb @@ -1611,8 +1611,6 @@ extension ${Self} { extension ${Self} { @_transparent - // SWIFT_ENABLE_TENSORFLOW - @differentiable(vjp: _vjpNegate(x:)) public static prefix func - (x: ${Self}) -> ${Self} { return ${Self}(Builtin.fneg_FPIEEE${bits}(x._value)) } @@ -1622,7 +1620,9 @@ extension ${Self} { @usableFromInline @_transparent // SWIFT_ENABLE_TENSORFLOW - static func _vjpNegate(x: ${Self}) -> (${Self}, (${Self}) -> ${Self}) { + @differentiating(-) + static func _vjpNegate(x: ${Self}) + -> (value: ${Self}, pullback: (${Self}) -> ${Self}) { return (-x, { v in -v }) } } @@ -1748,8 +1748,6 @@ extension ${Self} { extension ${Self} { @_transparent - /// SWIFT_ENABLE_TENSORFLOW - @differentiable(vjp: _vjpAdd(lhs:rhs:)) public static func + (lhs: ${Self}, rhs: ${Self}) -> ${Self} { var lhs = lhs lhs += rhs @@ -1757,8 +1755,6 @@ extension ${Self} { } @_transparent - /// SWIFT_ENABLE_TENSORFLOW - @differentiable(vjp: _vjpSubtract(lhs:rhs:)) public static func - (lhs: ${Self}, rhs: ${Self}) -> ${Self} { var lhs = lhs lhs -= rhs @@ -1766,8 +1762,6 @@ extension ${Self} { } @_transparent - /// SWIFT_ENABLE_TENSORFLOW - @differentiable(vjp: _vjpMultiply(lhs:rhs:)) public static func * (lhs: ${Self}, rhs: ${Self}) -> ${Self} { var lhs = lhs lhs *= rhs @@ -1775,8 +1769,6 @@ extension ${Self} { } @_transparent - /// SWIFT_ENABLE_TENSORFLOW - @differentiable(vjp: _vjpDivide(lhs:rhs:)) public static func / (lhs: ${Self}, rhs: ${Self}) -> ${Self} { var lhs = lhs lhs /= rhs @@ -1789,33 +1781,37 @@ extension ${Self} { extension ${Self} { @inlinable // FIXME(sil-serialize-all) @_transparent + @differentiating(+) static func _vjpAdd( lhs: ${Self}, rhs: ${Self} - ) -> (${Self}, (${Self}) -> (${Self}, ${Self})) { + ) -> (value: ${Self}, pullback: (${Self}) -> (${Self}, ${Self})) { return (lhs + rhs, { v in (v, v) }) } @inlinable // FIXME(sil-serialize-all) @_transparent + @differentiating(-) static func _vjpSubtract( lhs: ${Self}, rhs: ${Self} - ) -> (${Self}, (${Self}) -> (${Self}, ${Self})) { + ) -> (value: ${Self}, pullback: (${Self}) -> (${Self}, ${Self})) { return (lhs - rhs, { v in (v, -v) }) } @inlinable // FIXME(sil-serialize-all) @_transparent + @differentiating(*) static func _vjpMultiply( lhs: ${Self}, rhs: ${Self} - ) -> (${Self}, (${Self}) -> (${Self}, ${Self})) { + ) -> (value: ${Self}, pullback: (${Self}) -> (${Self}, ${Self})) { return (lhs * rhs, { v in (rhs * v, lhs * v) }) } @inlinable // FIXME(sil-serialize-all) @_transparent + @differentiating(/) static func _vjpDivide( lhs: ${Self}, rhs: ${Self} - ) -> (${Self}, (${Self}) -> (${Self}, ${Self})) { + ) -> (value: ${Self}, pullback: (${Self}) -> (${Self}, ${Self})) { return (lhs / rhs, { v in (v / rhs, -lhs / (rhs * rhs) * v) }) } } diff --git a/test/AutoDiff/differentiable_attr_silgen_cross_module.swift b/test/AutoDiff/differentiable_attr_silgen_cross_module.swift index 41045462df35e..de31a2d02bdd6 100644 --- a/test/AutoDiff/differentiable_attr_silgen_cross_module.swift +++ b/test/AutoDiff/differentiable_attr_silgen_cross_module.swift @@ -7,11 +7,11 @@ _ = gradient(at: Float(1)) { x in x + x * x } // CHECK-SILGEN-LABEL: // static Float.* infix(_:_:) -// CHECK-SILGEN-NEXT: sil [transparent] [serialized] [differentiable source 0 wrt 0, 1 vjp @$sSf12_vjpMultiply3lhs3rhsSf_Sf_SftSfctSf_SftFZ] @$sSf1moiyS2f_SftFZ : $@convention(method) (Float, Float, @thin Float.Type) -> Float +// CHECK-SILGEN-NEXT: sil [transparent] [serialized] [differentiable source 0 wrt 0, 1 vjp @$sSf12_vjpMultiply3lhs3rhsSf5value_Sf_SftSfc8pullbacktSf_SftFZ] @$sSf1moiyS2f_SftFZ : $@convention(method) (Float, Float, @thin Float.Type) -> Float // CHECK-SIL-LABEL: // static Float.* infix(_:_:) -// CHECK-SIL-NEXT: sil public_external [transparent] [serialized] [differentiable source 0 wrt 0, 1 jvp @AD__$sSf1moiyS2f_SftFZ__jvp_src_0_wrt_0_1 vjp @$sSf12_vjpMultiply3lhs3rhsSf_Sf_SftSfctSf_SftFZ] [differentiable source 0 wrt 0, 1 jvp @AD__$sSf1moiyS2f_SftFZ__jvp_src_0_wrt_0_1 vjp @$sSf12_vjpMultiply3lhs3rhsSf_Sf_SftSfctSf_SftFZ] @$sSf1moiyS2f_SftFZ : $@convention(method) (Float, Float, @thin Float.Type) -> Float +// CHECK-SIL-NEXT: sil public_external [transparent] [serialized] [differentiable source 0 wrt 0, 1 jvp @AD__$sSf1moiyS2f_SftFZ__jvp_src_0_wrt_0_1 vjp @$sSf12_vjpMultiply3lhs3rhsSf5value_Sf_SftSfc8pullbacktSf_SftFZ] [differentiable source 0 wrt 0, 1 jvp @AD__$sSf1moiyS2f_SftFZ__jvp_src_0_wrt_0_1 vjp @$sSf12_vjpMultiply3lhs3rhsSf5value_Sf_SftSfc8pullbacktSf_SftFZ] @$sSf1moiyS2f_SftFZ : $@convention(method) (Float, Float, @thin Float.Type) -> Float // CHECK-SILGEN-LABEL: // static Float.+ infix(_:_:) -// CHECK-SILGEN-NEXT: sil [transparent] [serialized] [differentiable source 0 wrt 0, 1 vjp @$sSf7_vjpAdd3lhs3rhsSf_Sf_SftSfctSf_SftFZ] @$sSf1poiyS2f_SftFZ : $@convention(method) (Float, Float, @thin Float.Type) -> Float +// CHECK-SILGEN-NEXT: sil [transparent] [serialized] [differentiable source 0 wrt 0, 1 vjp @$sSf7_vjpAdd3lhs3rhsSf5value_Sf_SftSfc8pullbacktSf_SftFZ] @$sSf1poiyS2f_SftFZ : $@convention(method) (Float, Float, @thin Float.Type) -> Float // CHECK-SIL-LABEL: // static Float.+ infix(_:_:) -// CHECK-SIL-NEXT: sil public_external [transparent] [serialized] [differentiable source 0 wrt 0, 1 jvp @AD__$sSf1poiyS2f_SftFZ__jvp_src_0_wrt_0_1 vjp @$sSf7_vjpAdd3lhs3rhsSf_Sf_SftSfctSf_SftFZ] [differentiable source 0 wrt 0, 1 jvp @AD__$sSf1poiyS2f_SftFZ__jvp_src_0_wrt_0_1 vjp @$sSf7_vjpAdd3lhs3rhsSf_Sf_SftSfctSf_SftFZ] @$sSf1poiyS2f_SftFZ : $@convention(method) (Float, Float, @thin Float.Type) -> Float +// CHECK-SIL-NEXT: sil public_external [transparent] [serialized] [differentiable source 0 wrt 0, 1 jvp @AD__$sSf1poiyS2f_SftFZ__jvp_src_0_wrt_0_1 vjp @$sSf7_vjpAdd3lhs3rhsSf5value_Sf_SftSfc8pullbacktSf_SftFZ] [differentiable source 0 wrt 0, 1 jvp @AD__$sSf1poiyS2f_SftFZ__jvp_src_0_wrt_0_1 vjp @$sSf7_vjpAdd3lhs3rhsSf5value_Sf_SftSfc8pullbacktSf_SftFZ] @$sSf1poiyS2f_SftFZ : $@convention(method) (Float, Float, @thin Float.Type) -> Float