From 4ff9a3b08a0e88d16887f3de0133e86c3806a461 Mon Sep 17 00:00:00 2001 From: Richard Wei Date: Fri, 8 Mar 2019 16:56:52 -0800 Subject: [PATCH 1/3] [AutoDiff] [stdlib] Switch to '@differentiating' as much as possible. `@differentiable(vjp: ...)` will be deprecated as soon as we support parameter selection and generic constraints with retroactive derivative registration (`@differentiating(...)`). --- stdlib/public/core/AutoDiff.swift | 13 ++++++----- .../public/core/FloatingPointTypes.swift.gyb | 22 ++++++------------- 2 files changed, 15 insertions(+), 20 deletions(-) diff --git a/stdlib/public/core/AutoDiff.swift b/stdlib/public/core/AutoDiff.swift index ebaa6296619c9..db0666109d857 100644 --- a/stdlib/public/core/AutoDiff.swift +++ b/stdlib/public/core/AutoDiff.swift @@ -145,11 +145,12 @@ public func differentiableFunction( from vjp: @escaping (T) -> (value: R, pullback: (R.CotangentVector) -> T.CotangentVector) ) -> @differentiable (T) -> R { - @differentiable(vjp: _vjp) func original(_ x: T) -> R { return vjp(x).value } - func _vjp(_ x: T) -> (R, (R.CotangentVector) -> T.CotangentVector) { + @differentiating(original) + func derivative(_ x: T) + -> (value: R, pullback: (R.CotangentVector) -> T.CotangentVector) { return vjp(x) } return original @@ -163,12 +164,14 @@ public func differentiableFunction( -> (T.CotangentVector, U.CotangentVector)) ) -> @differentiable (T, U) -> R where T : Differentiable, U : Differentiable, R : Differentiable { - @differentiable(vjp: _vjp) func original(_ x: T, _ y: U) -> R { return vjp(x, y).value } - func _vjp(_ x: T, _ y: U) - -> (R, (R.CotangentVector) -> (T.CotangentVector, U.CotangentVector)) { + @differentiating(original) + func derivative(_ x: T, _ y: U) + -> (value: R, + pullback: (R.CotangentVector) + -> (T.CotangentVector, U.CotangentVector)) { return vjp(x, y) } return original diff --git a/stdlib/public/core/FloatingPointTypes.swift.gyb b/stdlib/public/core/FloatingPointTypes.swift.gyb index b42f8526d2dc0..8ae0e9ad32587 100644 --- a/stdlib/public/core/FloatingPointTypes.swift.gyb +++ b/stdlib/public/core/FloatingPointTypes.swift.gyb @@ -1611,8 +1611,6 @@ extension ${Self} { extension ${Self} { @_transparent - // SWIFT_ENABLE_TENSORFLOW - @differentiable(vjp: _vjpNegate(x:)) public static prefix func - (x: ${Self}) -> ${Self} { return ${Self}(Builtin.fneg_FPIEEE${bits}(x._value)) } @@ -1622,7 +1620,9 @@ extension ${Self} { @usableFromInline @_transparent // SWIFT_ENABLE_TENSORFLOW - static func _vjpNegate(x: ${Self}) -> (${Self}, (${Self}) -> ${Self}) { + @differentiating(-) + static func _vjpNegate(x: ${Self}) + -> (value: ${Self}, pullback: (${Self}) -> ${Self}) { return (-x, { v in -v }) } } @@ -1748,8 +1748,6 @@ extension ${Self} { extension ${Self} { @_transparent - /// SWIFT_ENABLE_TENSORFLOW - @differentiable(vjp: _vjpAdd(lhs:rhs:)) public static func + (lhs: ${Self}, rhs: ${Self}) -> ${Self} { var lhs = lhs lhs += rhs @@ -1757,8 +1755,6 @@ extension ${Self} { } @_transparent - /// SWIFT_ENABLE_TENSORFLOW - @differentiable(vjp: _vjpSubtract(lhs:rhs:)) public static func - (lhs: ${Self}, rhs: ${Self}) -> ${Self} { var lhs = lhs lhs -= rhs @@ -1766,8 +1762,6 @@ extension ${Self} { } @_transparent - /// SWIFT_ENABLE_TENSORFLOW - @differentiable(vjp: _vjpMultiply(lhs:rhs:)) public static func * (lhs: ${Self}, rhs: ${Self}) -> ${Self} { var lhs = lhs lhs *= rhs @@ -1775,8 +1769,6 @@ extension ${Self} { } @_transparent - /// SWIFT_ENABLE_TENSORFLOW - @differentiable(vjp: _vjpDivide(lhs:rhs:)) public static func / (lhs: ${Self}, rhs: ${Self}) -> ${Self} { var lhs = lhs lhs /= rhs @@ -1791,7 +1783,7 @@ extension ${Self} { @_transparent static func _vjpAdd( lhs: ${Self}, rhs: ${Self} - ) -> (${Self}, (${Self}) -> (${Self}, ${Self})) { + ) -> (value: ${Self}, pullback: (${Self}) -> (${Self}, ${Self})) { return (lhs + rhs, { v in (v, v) }) } @@ -1799,7 +1791,7 @@ extension ${Self} { @_transparent static func _vjpSubtract( lhs: ${Self}, rhs: ${Self} - ) -> (${Self}, (${Self}) -> (${Self}, ${Self})) { + ) -> (value: ${Self}, pullback: (${Self}) -> (${Self}, ${Self})) { return (lhs - rhs, { v in (v, -v) }) } @@ -1807,7 +1799,7 @@ extension ${Self} { @_transparent static func _vjpMultiply( lhs: ${Self}, rhs: ${Self} - ) -> (${Self}, (${Self}) -> (${Self}, ${Self})) { + ) -> (value: ${Self}, pullback: (${Self}) -> (${Self}, ${Self})) { return (lhs * rhs, { v in (rhs * v, lhs * v) }) } @@ -1815,7 +1807,7 @@ extension ${Self} { @_transparent static func _vjpDivide( lhs: ${Self}, rhs: ${Self} - ) -> (${Self}, (${Self}) -> (${Self}, ${Self})) { + ) -> (value: ${Self}, pullback: (${Self}) -> (${Self}, ${Self})) { return (lhs / rhs, { v in (v / rhs, -lhs / (rhs * rhs) * v) }) } } From 83e9c7b0e287006eb1460d8c76904d17e925e1ab Mon Sep 17 00:00:00 2001 From: Dan Zheng Date: Sun, 17 Mar 2019 22:43:59 -0700 Subject: [PATCH 2/3] Add `@differentiating` attributes to floating-point operations. --- stdlib/public/core/FloatingPointTypes.swift.gyb | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/stdlib/public/core/FloatingPointTypes.swift.gyb b/stdlib/public/core/FloatingPointTypes.swift.gyb index 8ae0e9ad32587..18dd4f9b9aaab 100644 --- a/stdlib/public/core/FloatingPointTypes.swift.gyb +++ b/stdlib/public/core/FloatingPointTypes.swift.gyb @@ -1781,6 +1781,7 @@ extension ${Self} { extension ${Self} { @inlinable // FIXME(sil-serialize-all) @_transparent + @differentiating(+) static func _vjpAdd( lhs: ${Self}, rhs: ${Self} ) -> (value: ${Self}, pullback: (${Self}) -> (${Self}, ${Self})) { @@ -1789,6 +1790,7 @@ extension ${Self} { @inlinable // FIXME(sil-serialize-all) @_transparent + @differentiating(-) static func _vjpSubtract( lhs: ${Self}, rhs: ${Self} ) -> (value: ${Self}, pullback: (${Self}) -> (${Self}, ${Self})) { @@ -1797,6 +1799,7 @@ extension ${Self} { @inlinable // FIXME(sil-serialize-all) @_transparent + @differentiating(*) static func _vjpMultiply( lhs: ${Self}, rhs: ${Self} ) -> (value: ${Self}, pullback: (${Self}) -> (${Self}, ${Self})) { @@ -1805,6 +1808,7 @@ extension ${Self} { @inlinable // FIXME(sil-serialize-all) @_transparent + @differentiating(/) static func _vjpDivide( lhs: ${Self}, rhs: ${Self} ) -> (value: ${Self}, pullback: (${Self}) -> (${Self}, ${Self})) { From 9fdeed2608594bacce598f3b6b78a34233f5f2c7 Mon Sep 17 00:00:00 2001 From: Dan Zheng Date: Wed, 20 Mar 2019 11:04:08 -0700 Subject: [PATCH 3/3] Update derivative names for stdlib operators in tests. --- .../differentiable_attr_silgen_cross_module.swift | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/AutoDiff/differentiable_attr_silgen_cross_module.swift b/test/AutoDiff/differentiable_attr_silgen_cross_module.swift index 41045462df35e..de31a2d02bdd6 100644 --- a/test/AutoDiff/differentiable_attr_silgen_cross_module.swift +++ b/test/AutoDiff/differentiable_attr_silgen_cross_module.swift @@ -7,11 +7,11 @@ _ = gradient(at: Float(1)) { x in x + x * x } // CHECK-SILGEN-LABEL: // static Float.* infix(_:_:) -// CHECK-SILGEN-NEXT: sil [transparent] [serialized] [differentiable source 0 wrt 0, 1 vjp @$sSf12_vjpMultiply3lhs3rhsSf_Sf_SftSfctSf_SftFZ] @$sSf1moiyS2f_SftFZ : $@convention(method) (Float, Float, @thin Float.Type) -> Float +// CHECK-SILGEN-NEXT: sil [transparent] [serialized] [differentiable source 0 wrt 0, 1 vjp @$sSf12_vjpMultiply3lhs3rhsSf5value_Sf_SftSfc8pullbacktSf_SftFZ] @$sSf1moiyS2f_SftFZ : $@convention(method) (Float, Float, @thin Float.Type) -> Float // CHECK-SIL-LABEL: // static Float.* infix(_:_:) -// CHECK-SIL-NEXT: sil public_external [transparent] [serialized] [differentiable source 0 wrt 0, 1 jvp @AD__$sSf1moiyS2f_SftFZ__jvp_src_0_wrt_0_1 vjp @$sSf12_vjpMultiply3lhs3rhsSf_Sf_SftSfctSf_SftFZ] [differentiable source 0 wrt 0, 1 jvp @AD__$sSf1moiyS2f_SftFZ__jvp_src_0_wrt_0_1 vjp @$sSf12_vjpMultiply3lhs3rhsSf_Sf_SftSfctSf_SftFZ] @$sSf1moiyS2f_SftFZ : $@convention(method) (Float, Float, @thin Float.Type) -> Float +// CHECK-SIL-NEXT: sil public_external [transparent] [serialized] [differentiable source 0 wrt 0, 1 jvp @AD__$sSf1moiyS2f_SftFZ__jvp_src_0_wrt_0_1 vjp @$sSf12_vjpMultiply3lhs3rhsSf5value_Sf_SftSfc8pullbacktSf_SftFZ] [differentiable source 0 wrt 0, 1 jvp @AD__$sSf1moiyS2f_SftFZ__jvp_src_0_wrt_0_1 vjp @$sSf12_vjpMultiply3lhs3rhsSf5value_Sf_SftSfc8pullbacktSf_SftFZ] @$sSf1moiyS2f_SftFZ : $@convention(method) (Float, Float, @thin Float.Type) -> Float // CHECK-SILGEN-LABEL: // static Float.+ infix(_:_:) -// CHECK-SILGEN-NEXT: sil [transparent] [serialized] [differentiable source 0 wrt 0, 1 vjp @$sSf7_vjpAdd3lhs3rhsSf_Sf_SftSfctSf_SftFZ] @$sSf1poiyS2f_SftFZ : $@convention(method) (Float, Float, @thin Float.Type) -> Float +// CHECK-SILGEN-NEXT: sil [transparent] [serialized] [differentiable source 0 wrt 0, 1 vjp @$sSf7_vjpAdd3lhs3rhsSf5value_Sf_SftSfc8pullbacktSf_SftFZ] @$sSf1poiyS2f_SftFZ : $@convention(method) (Float, Float, @thin Float.Type) -> Float // CHECK-SIL-LABEL: // static Float.+ infix(_:_:) -// CHECK-SIL-NEXT: sil public_external [transparent] [serialized] [differentiable source 0 wrt 0, 1 jvp @AD__$sSf1poiyS2f_SftFZ__jvp_src_0_wrt_0_1 vjp @$sSf7_vjpAdd3lhs3rhsSf_Sf_SftSfctSf_SftFZ] [differentiable source 0 wrt 0, 1 jvp @AD__$sSf1poiyS2f_SftFZ__jvp_src_0_wrt_0_1 vjp @$sSf7_vjpAdd3lhs3rhsSf_Sf_SftSfctSf_SftFZ] @$sSf1poiyS2f_SftFZ : $@convention(method) (Float, Float, @thin Float.Type) -> Float +// CHECK-SIL-NEXT: sil public_external [transparent] [serialized] [differentiable source 0 wrt 0, 1 jvp @AD__$sSf1poiyS2f_SftFZ__jvp_src_0_wrt_0_1 vjp @$sSf7_vjpAdd3lhs3rhsSf5value_Sf_SftSfc8pullbacktSf_SftFZ] [differentiable source 0 wrt 0, 1 jvp @AD__$sSf1poiyS2f_SftFZ__jvp_src_0_wrt_0_1 vjp @$sSf7_vjpAdd3lhs3rhsSf5value_Sf_SftSfc8pullbacktSf_SftFZ] @$sSf1poiyS2f_SftFZ : $@convention(method) (Float, Float, @thin Float.Type) -> Float