Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions stdlib/public/core/AutoDiff.swift
Original file line number Diff line number Diff line change
Expand Up @@ -145,11 +145,12 @@ public func differentiableFunction<T : Differentiable, R : Differentiable>(
from vjp: @escaping (T)
-> (value: R, pullback: (R.CotangentVector) -> T.CotangentVector)
) -> @differentiable (T) -> R {
@differentiable(vjp: _vjp)
func original(_ x: T) -> R {
return vjp(x).value
}
func _vjp(_ x: T) -> (R, (R.CotangentVector) -> T.CotangentVector) {
@differentiating(original)
func derivative(_ x: T)
-> (value: R, pullback: (R.CotangentVector) -> T.CotangentVector) {
return vjp(x)
}
return original
Expand All @@ -163,12 +164,14 @@ public func differentiableFunction<T, U, R>(
-> (T.CotangentVector, U.CotangentVector))
) -> @differentiable (T, U) -> R
where T : Differentiable, U : Differentiable, R : Differentiable {
@differentiable(vjp: _vjp)
func original(_ x: T, _ y: U) -> R {
return vjp(x, y).value
}
func _vjp(_ x: T, _ y: U)
-> (R, (R.CotangentVector) -> (T.CotangentVector, U.CotangentVector)) {
@differentiating(original)
func derivative(_ x: T, _ y: U)
-> (value: R,
pullback: (R.CotangentVector)
-> (T.CotangentVector, U.CotangentVector)) {
return vjp(x, y)
}
return original
Expand Down
26 changes: 11 additions & 15 deletions stdlib/public/core/FloatingPointTypes.swift.gyb
Original file line number Diff line number Diff line change
Expand Up @@ -1611,8 +1611,6 @@ extension ${Self} {

extension ${Self} {
@_transparent
// SWIFT_ENABLE_TENSORFLOW
@differentiable(vjp: _vjpNegate(x:))
public static prefix func - (x: ${Self}) -> ${Self} {
return ${Self}(Builtin.fneg_FPIEEE${bits}(x._value))
}
Expand All @@ -1622,7 +1620,9 @@ extension ${Self} {
@usableFromInline
@_transparent
// SWIFT_ENABLE_TENSORFLOW
static func _vjpNegate(x: ${Self}) -> (${Self}, (${Self}) -> ${Self}) {
@differentiating(-)
static func _vjpNegate(x: ${Self})
-> (value: ${Self}, pullback: (${Self}) -> ${Self}) {
return (-x, { v in -v })
}
}
Expand Down Expand Up @@ -1748,35 +1748,27 @@ extension ${Self} {

extension ${Self} {
@_transparent
/// SWIFT_ENABLE_TENSORFLOW
@differentiable(vjp: _vjpAdd(lhs:rhs:))
public static func + (lhs: ${Self}, rhs: ${Self}) -> ${Self} {
var lhs = lhs
lhs += rhs
return lhs
}

@_transparent
/// SWIFT_ENABLE_TENSORFLOW
@differentiable(vjp: _vjpSubtract(lhs:rhs:))
public static func - (lhs: ${Self}, rhs: ${Self}) -> ${Self} {
var lhs = lhs
lhs -= rhs
return lhs
}

@_transparent
/// SWIFT_ENABLE_TENSORFLOW
@differentiable(vjp: _vjpMultiply(lhs:rhs:))
public static func * (lhs: ${Self}, rhs: ${Self}) -> ${Self} {
var lhs = lhs
lhs *= rhs
return lhs
}

@_transparent
/// SWIFT_ENABLE_TENSORFLOW
@differentiable(vjp: _vjpDivide(lhs:rhs:))
public static func / (lhs: ${Self}, rhs: ${Self}) -> ${Self} {
var lhs = lhs
lhs /= rhs
Expand All @@ -1789,33 +1781,37 @@ extension ${Self} {
extension ${Self} {
@inlinable // FIXME(sil-serialize-all)
@_transparent
@differentiating(+)
static func _vjpAdd(
lhs: ${Self}, rhs: ${Self}
) -> (${Self}, (${Self}) -> (${Self}, ${Self})) {
) -> (value: ${Self}, pullback: (${Self}) -> (${Self}, ${Self})) {
return (lhs + rhs, { v in (v, v) })
}

@inlinable // FIXME(sil-serialize-all)
@_transparent
@differentiating(-)
static func _vjpSubtract(
lhs: ${Self}, rhs: ${Self}
) -> (${Self}, (${Self}) -> (${Self}, ${Self})) {
) -> (value: ${Self}, pullback: (${Self}) -> (${Self}, ${Self})) {
return (lhs - rhs, { v in (v, -v) })
}

@inlinable // FIXME(sil-serialize-all)
@_transparent
@differentiating(*)
static func _vjpMultiply(
lhs: ${Self}, rhs: ${Self}
) -> (${Self}, (${Self}) -> (${Self}, ${Self})) {
) -> (value: ${Self}, pullback: (${Self}) -> (${Self}, ${Self})) {
return (lhs * rhs, { v in (rhs * v, lhs * v) })
}

@inlinable // FIXME(sil-serialize-all)
@_transparent
@differentiating(/)
static func _vjpDivide(
lhs: ${Self}, rhs: ${Self}
) -> (${Self}, (${Self}) -> (${Self}, ${Self})) {
) -> (value: ${Self}, pullback: (${Self}) -> (${Self}, ${Self})) {
return (lhs / rhs, { v in (v / rhs, -lhs / (rhs * rhs) * v) })
}
}
Expand Down
8 changes: 4 additions & 4 deletions test/AutoDiff/differentiable_attr_silgen_cross_module.swift
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@
_ = gradient(at: Float(1)) { x in x + x * x }

// CHECK-SILGEN-LABEL: // static Float.* infix(_:_:)
// CHECK-SILGEN-NEXT: sil [transparent] [serialized] [differentiable source 0 wrt 0, 1 vjp @$sSf12_vjpMultiply3lhs3rhsSf_Sf_SftSfctSf_SftFZ] @$sSf1moiyS2f_SftFZ : $@convention(method) (Float, Float, @thin Float.Type) -> Float
// CHECK-SILGEN-NEXT: sil [transparent] [serialized] [differentiable source 0 wrt 0, 1 vjp @$sSf12_vjpMultiply3lhs3rhsSf5value_Sf_SftSfc8pullbacktSf_SftFZ] @$sSf1moiyS2f_SftFZ : $@convention(method) (Float, Float, @thin Float.Type) -> Float
// CHECK-SIL-LABEL: // static Float.* infix(_:_:)
// CHECK-SIL-NEXT: sil public_external [transparent] [serialized] [differentiable source 0 wrt 0, 1 jvp @AD__$sSf1moiyS2f_SftFZ__jvp_src_0_wrt_0_1 vjp @$sSf12_vjpMultiply3lhs3rhsSf_Sf_SftSfctSf_SftFZ] [differentiable source 0 wrt 0, 1 jvp @AD__$sSf1moiyS2f_SftFZ__jvp_src_0_wrt_0_1 vjp @$sSf12_vjpMultiply3lhs3rhsSf_Sf_SftSfctSf_SftFZ] @$sSf1moiyS2f_SftFZ : $@convention(method) (Float, Float, @thin Float.Type) -> Float
// CHECK-SIL-NEXT: sil public_external [transparent] [serialized] [differentiable source 0 wrt 0, 1 jvp @AD__$sSf1moiyS2f_SftFZ__jvp_src_0_wrt_0_1 vjp @$sSf12_vjpMultiply3lhs3rhsSf5value_Sf_SftSfc8pullbacktSf_SftFZ] [differentiable source 0 wrt 0, 1 jvp @AD__$sSf1moiyS2f_SftFZ__jvp_src_0_wrt_0_1 vjp @$sSf12_vjpMultiply3lhs3rhsSf5value_Sf_SftSfc8pullbacktSf_SftFZ] @$sSf1moiyS2f_SftFZ : $@convention(method) (Float, Float, @thin Float.Type) -> Float

// CHECK-SILGEN-LABEL: // static Float.+ infix(_:_:)
// CHECK-SILGEN-NEXT: sil [transparent] [serialized] [differentiable source 0 wrt 0, 1 vjp @$sSf7_vjpAdd3lhs3rhsSf_Sf_SftSfctSf_SftFZ] @$sSf1poiyS2f_SftFZ : $@convention(method) (Float, Float, @thin Float.Type) -> Float
// CHECK-SILGEN-NEXT: sil [transparent] [serialized] [differentiable source 0 wrt 0, 1 vjp @$sSf7_vjpAdd3lhs3rhsSf5value_Sf_SftSfc8pullbacktSf_SftFZ] @$sSf1poiyS2f_SftFZ : $@convention(method) (Float, Float, @thin Float.Type) -> Float
// CHECK-SIL-LABEL: // static Float.+ infix(_:_:)
// CHECK-SIL-NEXT: sil public_external [transparent] [serialized] [differentiable source 0 wrt 0, 1 jvp @AD__$sSf1poiyS2f_SftFZ__jvp_src_0_wrt_0_1 vjp @$sSf7_vjpAdd3lhs3rhsSf_Sf_SftSfctSf_SftFZ] [differentiable source 0 wrt 0, 1 jvp @AD__$sSf1poiyS2f_SftFZ__jvp_src_0_wrt_0_1 vjp @$sSf7_vjpAdd3lhs3rhsSf_Sf_SftSfctSf_SftFZ] @$sSf1poiyS2f_SftFZ : $@convention(method) (Float, Float, @thin Float.Type) -> Float
// CHECK-SIL-NEXT: sil public_external [transparent] [serialized] [differentiable source 0 wrt 0, 1 jvp @AD__$sSf1poiyS2f_SftFZ__jvp_src_0_wrt_0_1 vjp @$sSf7_vjpAdd3lhs3rhsSf5value_Sf_SftSfc8pullbacktSf_SftFZ] [differentiable source 0 wrt 0, 1 jvp @AD__$sSf1poiyS2f_SftFZ__jvp_src_0_wrt_0_1 vjp @$sSf7_vjpAdd3lhs3rhsSf5value_Sf_SftSfc8pullbacktSf_SftFZ] @$sSf1poiyS2f_SftFZ : $@convention(method) (Float, Float, @thin Float.Type) -> Float