[CSApply] Don't attempt operator devirtualization #84800
Conversation
@asl and @kovdan01 This change breaks some test cases in …
@swift-ci please test source compatibility
It cannot be moved there. One reason is that derivatives are defined on floating-point types (…)
@asl What are the alternatives here? Unfortunately, we cannot be doing this during CSApply, because at the moment it creates more problems than it solves (which is pretty much only AutoDiff).
This is actually very close to a known recent issue where the following code stopped working in Swift 6 mode while it still works in Swift 5:

```swift
import _Differentiation

@differentiable(reverse)
func addition(a: SIMD8<Double>, b: SIMD8<Double>) -> SIMD8<Double> {
    a + b
}
```

It used to be a direct static call in Swift 5 mode, and in Swift 6 it is a virtual call, exactly due to the removed code. We're having … In Swift 5 mode it asks for … The virtual call to a protocol method is not differentiable, as we cannot retroactively register derivatives for protocols; we simply do not have a way to represent that now using the existing witness tables. See #54231 for some discussion.
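For context, retroactive derivative registration does work for concrete functions today; what is missing is an equivalent for calls that dispatch through a protocol witness table. A minimal sketch, assuming the stdlib's `Differentiable` conformance for `SIMD8<Double>` (`concreteAdd`/`vjpConcreteAdd` are hypothetical names for illustration):

```swift
import _Differentiation

// A plain concrete function (hypothetical, for illustration only).
func concreteAdd(_ a: SIMD8<Double>, _ b: SIMD8<Double>) -> SIMD8<Double> {
    a + b
}

// A derivative can be registered retroactively for a concrete function,
// because the differentiability witness is keyed by that exact function...
@derivative(of: concreteAdd)
func vjpConcreteAdd(_ a: SIMD8<Double>, _ b: SIMD8<Double>)
    -> (value: SIMD8<Double>, pullback: (SIMD8<Double>) -> (SIMD8<Double>, SIMD8<Double>)) {
    (concreteAdd(a, b), { v in (v, v) })
}

// ...but there is no analogous way to attach a derivative to the
// `AdditiveArithmetic.+` requirement that the Swift 6 virtual call
// above dispatches through.
```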
Well... we need to be able to register custom derivatives for protocol methods (as mentioned in the issue). The code here is not actually autodiff-specific; it just happened to fit nicely into derivative registration and lookup :) I was planning to work on this; however, I'm not sure it will be doable in the near future, given that some common Swift changes take months to receive review comments, unfortunately... Can it be made temporarily conditional on Swift 5 mode? It does not work in Swift 6 due to the reasons outlined above :) And we can try to speed up the work on retroactive protocol method differentiability, which would essentially restore the implementation here. Which tests are currently broken?
Local failures: …
We could XFAIL them for now, since you are planning to work on this.
Let me check them locally with this PR and get back to you.
Thanks!
@asl I suspect you could add some code to … There are two key steps: …
If you get this working, it will fix your Swift 6 mode issue with operators as well. The devirtualization stopped working by accident in Swift 6 because of … We would like to remove this devirtualization because it is fundamentally incompatible with the upcoming …
@swift-ci Please test source compatibility
Well, we noticed :) Like in the SIMD example above. It's just that the user's codebase only started moving to Swift 6 recently.
My problem here is the representation of protocol method derivatives, as we need to be able to look up derivatives that were registered retroactively. Consider e.g. the following code (it won't compile right now, due to Sema prohibiting this, plus a couple of existing PRs):

```swift
import _Differentiation

protocol P {
func add(rhs: Self) -> Self
}
extension P where Self: Differentiable {
@derivative(of: add)
func vjpAdd(rhs: Self) -> (value: Self, pullback: (TangentVector) -> (TangentVector, TangentVector)) {
return (add(rhs:rhs), { v in (v, v) })
}
}
struct S : P, Differentiable {
var s : Float;
func add(rhs: Self) -> Self {
return S(s: s + rhs.s);
}
}
@differentiable(reverse where T: Differentiable)
func foo<T : P>(x : T, y : T) -> T {
return x.add(rhs: y);
}
func bar(x : Float) -> Float {
let s = S(s: x)
return foo(x: s, y: s).s;
}
let _ = gradient(at: Float(1.0), of: bar)
```

The main interesting function here is … What we're having: …
Now the problem is that we're having a derivative for … If …
I couldn’t really follow your example since I’m not familiar with the jargon used here, but from a semantic standpoint a …
Well, let me show an example of how it is currently implemented. Consider the following (supported) case:

```swift
import _Differentiation

protocol P : Differentiable {
@differentiable(reverse, wrt:(self, rhs))
func add(rhs: Self) -> Self
}
struct S : P {
var s : Float;
func add(rhs: Self) -> Self {
return S(s: s + rhs.s);
}
}
@differentiable(reverse where T: Differentiable)
func foo<T : P>(x : T, y : T) -> T {
return x.add(rhs: y);
}
func bar(x : Float) -> Float {
let s = S(s: x)
return foo(x: s, y: s).s;
}
let _ = gradient(at: Float(1.0), of: bar)
```

It currently works as follows: …

```sil
sil_witness_table hidden S: P module proto {
base_protocol Differentiable: S: Differentiable module proto
method #P.add: <Self where Self : P> (Self) -> (Self) -> Self : @$s5proto1SVAA1PA2aDP3add3rhsxx_tFTW // protocol witness for P.add(rhs:) in conformance S
method #P.add!jvp.SS.<Self where Self : P>: <Self where Self : P> (Self) -> (Self) -> Self : @AD__$s5proto1SVAA1PA2aDP3add3rhsxx_tFTW_jvp_SS // AD__$s5proto1SVAA1PA2aDP3add3rhsxx_tFTW_jvp_SS
method #P.add!vjp.SS.<Self where Self : P>: <Self where Self : P> (Self) -> (Self) -> Self : @AD__$s5proto1SVAA1PA2aDP3add3rhsxx_tFTW_vjp_SS // AD__$s5proto1SVAA1PA2aDP3add3rhsxx_tFTW_vjp_SS
}
```

…with the jvp/vjp entries implementing the forward and reverse derivatives.

```sil
// foo<A>(x:y:)
sil hidden @$s5proto3foo1x1yxx_xtAA1PRzlF : $@convention(thin) <T where T : P> (@in_guaranteed T, @in_guaranteed T) -> @out T {
...
%5 = witness_method $T, #P.add : <Self where Self : P> (Self) -> (Self) -> Self : $@convention(witness_method: P) <τ_0_0 where τ_0_0 : P> (@in_guaranteed τ_0_0, @in_guaranteed τ_0_0) -> @out τ_0_0 // user: %6
%6 = apply %5<T>(%0, %2, %1) : $@convention(witness_method: P) <τ_0_0 where τ_0_0 : P> (@in_guaranteed τ_0_0, @in_guaranteed τ_0_0) -> @out τ_0_0
...
} // end sil function '$s5proto3foo1x1yxx_xtAA1PRzlF'
// reverse-mode derivative of foo<A>(x:y:)
sil hidden @$s5proto3foo1x1yxx_xtAA1PRzlFAaERzlTJrSSpSr : $@convention(thin) <τ_0_0 where τ_0_0 : P> (@in_guaranteed τ_0_0, @in_guaranteed τ_0_0) -> (@out τ_0_0, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1, τ_0_2> (@in_guaranteed τ_0_0) -> (@out τ_0_1, @out τ_0_2) for <τ_0_0.TangentVector, τ_0_0.TangentVector, τ_0_0.TangentVector>) {
...
%7 = witness_method $τ_0_0, #P.add!jvp.SS.<Self where Self : P> : <Self where Self : P> (Self) -> (Self) -> Self : $@convention(witness_method: P) <τ_0_0 where τ_0_0 : P> (@in_guaranteed τ_0_0, @in_guaranteed τ_0_0) -> (@out τ_0_0, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1, τ_0_2> (@in_guaranteed τ_0_0, @in_guaranteed τ_0_1) -> @out τ_0_2 for <τ_0_0.TangentVector, τ_0_0.TangentVector, τ_0_0.TangentVector>) // user: %8
...
```

Now, as in my example above. If … I guess …:

```sil
%11 = witness_method $SIMD8<Double>, #AdditiveArithmetic."+" : <Self where Self : AdditiveArithmetic> (Self.Type) -> (Self, Self) -> Self : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@in_guaranteed τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> @out τ_0_0 // user: %12
%12 = apply %11<SIMD8<Double>>(%4, %7, %9, %6) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@in_guaranteed τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> @out τ_0_0
```

Am I understanding correctly?
@xedin I checked the failed tests. I would prefer not to …
Sure, that’s what I meant by XFAIL, really. I will file a GitHub issue to keep track of re-enabling them as well.
```diff
@@ -1,5 +1,7 @@
 // RUN: %target-swift-frontend -emit-silgen %s | %FileCheck %s
+
+// Operators are no longer devirtualized at AST level, it's done during SIL optimization.
```
If you want, you could add an `-emit-sil` test to `CHECK:` that the `witness_method` is now a `function_ref`.
✅
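For illustration, here is a minimal sketch of the kind of test being suggested; the protocol, type, and CHECK patterns are assumptions for the example, not the actual test from the PR:

```swift
// RUN: %target-swift-frontend -emit-sil %s | %FileCheck %s

// An operator declared as a protocol requirement and witnessed by a
// concrete type (all names here are illustrative).
protocol Addable {
  static func + (lhs: Self, rhs: Self) -> Self
}

struct Meters: Addable {
  var value: Double
  static func + (lhs: Meters, rhs: Meters) -> Meters {
    Meters(value: lhs.value + rhs.value)
  }
}

// With this PR, SILGen emits a witness_method call here instead of a
// devirtualized reference; by the time -emit-sil output is produced,
// the mandatory SIL passes are expected to have turned it back into a
// direct function_ref.
func add(_ a: Meters, _ b: Meters) -> Meters {
  a + b
}

// CHECK-LABEL: // add(_:_:)
// CHECK-NOT: witness_method
// CHECK: function_ref
```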
@slavapestov I'm afraid your approach only works for methods defined in the same module that is being compiled (so it won't solve …). The way autodiff works is that when we see a function call, we try to find or synthesize the derivative. The derivative lookup is done via differentiability witness tables. If there is no differentiability witness for the function, we can trigger differentiation of the function, provided that it is not external. The important difference here is that the derivative is registered for the original operator … While we can import the body of the protocol witness thunk for … Maybe one way to circumvent this is to trigger differentiation of protocol witness thunks as soon as a derivative is registered for the method, so we can look up this derivative as well. @xedin That said, I was able to apply @slavapestov's suggestion; it resolved ~50% of the autodiff test failures. The ones remaining are …
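As a rough pseudocode sketch of the lookup flow just described (every name below is a hypothetical stub for compiler internals, not real API):

```swift
// Hypothetical stubs standing in for compiler internals.
struct SILFunctionHandle { var isExternal: Bool }
struct DerivativeHandle {}

func differentiabilityWitness(for fn: SILFunctionHandle) -> DerivativeHandle? {
    nil // stub: would consult the differentiability witness tables
}

func differentiate(_ fn: SILFunctionHandle) -> DerivativeHandle {
    DerivativeHandle() // stub: would run the differentiation transform
}

func derivative(of fn: SILFunctionHandle) -> DerivativeHandle? {
    // 1. A derivative registered with `@derivative(of:)` is found through
    //    the differentiability witness tables.
    if let witness = differentiabilityWitness(for: fn) {
        return witness
    }
    // 2. Otherwise differentiation of the function can be triggered on the
    //    fly, but only if its body is visible, i.e. it is not external.
    guard !fn.isExternal else { return nil }
    return differentiate(fn)
}
```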
This appears to only support synthesized conformances, because operators in such cases use different names: `__derived_*`. These days, devirtualization like this is performed as part of mandatory inlining. And this approach doesn't stack well with features like `MemberImportVisibility`, because there is a need to check whether the witness is available or not.
@xedin asl@aef0cc9 implements @slavapestov's suggestion. It does not always help, but it covers the local cases.
@asl Sounds good, please push it to my branch. Are you going to handle XFAIL'ing the not-yet-fixed ones too, or should I?
Force-pushed from ea9a8b3 to 9d00c7b
I'm going to play a bit with the idea of automatically triggering differentiation for protocol witness thunks when there is a custom derivative registered for the method. This might fix the remaining issues (hopefully)…
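A sketch of what that could mean, in hypothetical pseudocode (the types and functions below are stubs for compiler internals, not real API): when a custom derivative is registered for a requirement, eagerly differentiate each conformance's witness thunk so the derivative can later be found through the witness table.

```swift
// Hypothetical stubs standing in for compiler internals.
struct Requirement {}
struct WitnessThunk {}
struct Derivative {}

final class Conformance {
    func witnessThunk(for req: Requirement) -> WitnessThunk { WitnessThunk() }
    func addDerivativeEntry(for req: Requirement, _ d: Derivative) {
        // stub: would append a jvp/vjp entry to the SIL witness table
    }
}

func conformances(of req: Requirement) -> [Conformance] { [] } // stub
func differentiate(_ thunk: WitnessThunk) -> Derivative { Derivative() } // stub

// On derivative registration, differentiate every conformance's witness
// thunk so a jvp/vjp entry exists for later derivative lookup.
func didRegisterDerivative(for req: Requirement) {
    for conformance in conformances(of: req) {
        let thunk = conformance.witnessThunk(for: req)
        conformance.addDerivativeEntry(for: req, differentiate(thunk))
    }
}
```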
I am going to rewrite some tests not to use …
Sounds good!
@swift-ci please test
@xedin I updated the tests not to use …
Looks like there is a crash in …
@xedin Yeah. But it looks like it is Linux-specific and happens in autodiff closure specialization:
Tagging @kovdan01 to take a look. Likely due to the recent autodiff closure specialization pass rewrite by @eeckstein.