Contributor

@xedin xedin commented Oct 9, 2025

This appears to only support synthesized conformances because operators in such cases use different names - __derived_*.

These days devirtualization like this is performed as part of mandatory inlining. And this approach doesn't stack well with features like MemberImportVisibility because there is change to check whether witness is available or not.

@xedin
Contributor Author

xedin commented Oct 9, 2025

@asl and @kovdan01 This change breaks some test cases in the AutoDiff directory because it looks like the mandatory pass relies on this hack. The question is whether it's okay to inline @_transparent functions before "differentiation" and if so I think the differentiation pass should be moved down in the pipeline to after mandatory inlining?...

@xedin
Contributor Author

xedin commented Oct 9, 2025

@swift-ci please test source compatibility

@asl
Contributor

asl commented Oct 9, 2025

and if so I think the differentiation pass should be moved down in the pipeline to after mandatory inlining?...

It cannot be moved there. One reason is that derivatives are defined on floating-point types (Float, Double, etc.). After mandatory inlining, no derivatives would be found, as nothing is available on Builtin.IEEE types.

@xedin
Contributor Author

xedin commented Oct 9, 2025

@asl What are the alternatives here? Unfortunately, we cannot keep doing this during CSApply because at the moment it creates more problems than it solves (and what it solves is pretty much only AutoDiff).

@asl
Contributor

asl commented Oct 9, 2025

This is actually very close to a known recent issue where the following code stopped working in Swift 6 mode, while it still works in Swift 5:

import _Differentiation
@differentiable(reverse)
func addition(a: SIMD8<Double>, b: SIMD8<Double>) -> SIMD8<Double> {
    a + b
}

It used to be a direct static call in Swift 5 mode, and in Swift 6 it is a virtual call, exactly due to the removed code. We have protocol AdditiveArithmetic, then struct SIMD : AdditiveArithmetic, and struct SIMD8: SIMD, with operator+ defined in SIMD. SIMD is not Sendable, SIMD8 is.

In Swift 5 mode it asks for operator+, and the witness and requirement types match. In Swift 6 mode it asks for @Sendable operator+, but the witness type is not Sendable. So it emits a virtual call and a conversion to @Sendable.
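
The mismatch described above can be sketched with a toy model. This is only an illustration: `FunctionType`, `CallKind`, and `callKind(requirement:witness:)` are hypothetical names, not compiler APIs, and the model assumes the call is devirtualized only when the requested requirement type equals the witness type exactly.

```swift
// Toy model only: these types are hypothetical and do not mirror the
// compiler's real representation of function types.
struct FunctionType: Equatable {
    var parameters: [String]
    var result: String
    var isSendable: Bool
}

enum CallKind: Equatable {
    case directStaticCall   // the witness is called directly
    case witnessMethodCall  // the call goes through the protocol requirement
}

/// Assumption: devirtualization happens only on an exact type match.
func callKind(requirement: FunctionType, witness: FunctionType) -> CallKind {
    requirement == witness ? .directStaticCall : .witnessMethodCall
}

let simd = "SIMD8<Double>"

// The witness for `+` comes from the non-Sendable SIMD, so it is not @Sendable.
let witness = FunctionType(parameters: [simd, simd], result: simd, isSendable: false)

// Swift 5 mode asks for a plain `+`; Swift 6 mode asks for a @Sendable `+`.
let swift5Requirement = FunctionType(parameters: [simd, simd], result: simd, isSendable: false)
let swift6Requirement = FunctionType(parameters: [simd, simd], result: simd, isSendable: true)

print(callKind(requirement: swift5Requirement, witness: witness)) // directStaticCall
print(callKind(requirement: swift6Requirement, witness: witness)) // witnessMethodCall
```

The only difference between the two requests is the `isSendable` flag, which is enough to defeat an exact-match check.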

A virtual call to a protocol method is not differentiable, as we cannot retroactively register derivatives for protocols; we simply do not have a way to represent this using the existing witness tables. See #54231 for some discussion.

@asl
Contributor

asl commented Oct 9, 2025

@asl What are the alternatives here? We cannot be doing this during CSApply unfortunately because it creates more problems than it solves (which is pretty much only AutoDiff) at the moment.

Well... we need to be able to register custom derivatives for protocol methods (as mentioned in the issue). The code here is not actually autodiff-specific; it just happened to fit nicely into derivative registration and lookup :)

I was planning to work on this. However, I'm not sure it will be doable in the near future, given that some common Swift changes take months to receive review comments, unfortunately...

Can it be made temporarily conditional on Swift 5 mode? It does not work in Swift 6 for the reasons outlined above :) And we can try to speed up the work on retroactive protocol method differentiability, which will essentially restore the implementation here.

Which tests are currently broken?

@xedin
Contributor Author

xedin commented Oct 10, 2025

Local failures:

  Swift(macosx-arm64) :: AutoDiff/SILOptimizer/generics.swift
  Swift(macosx-arm64) :: AutoDiff/stdlib/simd.swift
  Swift(macosx-arm64) :: AutoDiff/validation-test/class_differentiation.swift
  Swift(macosx-arm64) :: AutoDiff/validation-test/differentiable_property.swift
  Swift(macosx-arm64) :: AutoDiff/validation-test/existential.swift
  Swift(macosx-arm64) :: AutoDiff/validation-test/forward_mode_simd.swift
  Swift(macosx-arm64) :: AutoDiff/validation-test/forward_mode_simple.swift
  Swift(macosx-arm64) :: AutoDiff/validation-test/method.swift
  Swift(macosx-arm64) :: AutoDiff/validation-test/repeated_calls.swift
  Swift(macosx-arm64) :: AutoDiff/validation-test/simple_math.swift

We could XFAIL them for now since you are planning to work on this.

@asl
Contributor

asl commented Oct 10, 2025

We could XFAIL them for now since you are planning to work on this.

Let me check them locally with this PR and I will get back to you.

@xedin
Contributor Author

xedin commented Oct 10, 2025

Thanks!

@slavapestov
Contributor

@asl I suspect you could add some code to DifferentiationTransformer::emitDerivativeFunctionReference() which will handle a WitnessMethodInst that contains a concrete conformance in the same way that you handle a FunctionRefInst today. I couldn't quite get this to work, but it should be possible.

There are two key steps:

  • You can call lookUpFunctionInWitnessTable() to get the SILFunction for the witness thunk
  • You can call getWitnessMethodSubstitutions() to convert the SubstitutionMap for the requirement's generic signature into a SubstitutionMap for the witness thunk
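
The first step can be pictured with a small toy model. This is not the SIL API: `Conformance`, `Callee`, `ToyModule`, and `lookUpWitnessThunk(for:in:)` are hypothetical stand-ins, and the substitution-map conversion from the second step is left out entirely. The sketch only shows the idea of resolving a witness_method that carries a concrete conformance to the witness thunk, while leaving abstract conformances alone.

```swift
// Toy model only: none of these types are real SIL data structures.
struct Conformance: Hashable {
    var conformingType: String
    var protocolName: String
}

enum Callee: Equatable {
    // `conformance` is nil when the conforming type is still abstract (e.g. a generic T).
    case witnessMethod(requirement: String, conformance: Conformance?)
    case functionRef(String)
}

struct ToyModule {
    // conformance -> (requirement name -> witness thunk name)
    var witnessTables: [Conformance: [String: String]] = [:]

    /// Hypothetical stand-in for a witness table lookup.
    func lookUpWitnessThunk(for requirement: String, in conformance: Conformance) -> String? {
        witnessTables[conformance]?[requirement]
    }

    /// Rewrites a witness method with a concrete conformance into a direct
    /// reference to the thunk; every other callee is returned unchanged.
    func resolve(_ callee: Callee) -> Callee {
        guard case let .witnessMethod(requirement: requirement, conformance: conformance?) = callee,
              let thunk = lookUpWitnessThunk(for: requirement, in: conformance)
        else { return callee }
        return .functionRef(thunk)
    }
}

let sConformsToP = Conformance(conformingType: "S", protocolName: "P")
var toyModule = ToyModule()
toyModule.witnessTables[sConformsToP] = ["P.add": "thunk(P.add in S)"]

let concrete = Callee.witnessMethod(requirement: "P.add", conformance: sConformsToP)
let abstract = Callee.witnessMethod(requirement: "P.add", conformance: nil)

print(toyModule.resolve(concrete))
print(toyModule.resolve(abstract))
```

In this model only the concrete case becomes a direct reference; the abstract case has no witness table to consult and stays a witness method.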

If you get this working, it will fix your Swift 6 mode issue with operators as well. The devirtualization stopped working by accident in Swift 6 because of @Sendable; this was not intentional, but nobody noticed because, as @xedin mentioned, this is handled in SIL now anyway.

We would like to remove this devirtualization because it is fundamentally incompatible with the upcoming MemberImportVisibility feature as well, and the fact that AutoDiff relies on it is a major hurdle.

@slavapestov
Contributor

@swift-ci Please test source compatibility

@asl
Contributor

asl commented Oct 10, 2025

@slavapestov

The devirtualization stopped working by accident in Swift 6 because of @Sendable, this was not intentional but nobody noticed because as @xedin mentioned this is handled in SIL now anyway.

Well, we noticed :) as in the SIMD example above. It's just that users' codebases only recently started moving to Swift 6.

suspect you could add some code to DifferentiationTransformer::emitDerivativeFunctionReference() which will handle a WitnessMethodInst that contains a concrete conformance in the same way that you handle a FunctionRefInst today.

My problem here is how derivatives of protocol methods are represented, as we need to be able to look up derivatives registered retroactively. Consider, e.g., the following code (it won't compile right now because Sema prohibits this, plus a couple of existing PRs):

import _Differentiation

protocol P {
  func add(rhs: Self) -> Self
}

extension P where Self: Differentiable {
  @derivative(of: add)
  func vjpAdd(rhs: Self) -> (value: Self, pullback: (TangentVector) -> (TangentVector, TangentVector)) {
    return (add(rhs:rhs), { v in (v, v) })
  }
}

struct S : P, Differentiable {
  var s : Float;
  func add(rhs: Self) -> Self {
    return S(s: s + rhs.s);
  }
}

@differentiable(reverse where T: Differentiable)
func foo<T : P>(x : T, y : T) -> T {
  return x.add(rhs: y);
}

func bar(x : Float) -> Float {
  let s = S(s: x)
  return foo(x: s, y: s).s;
}

let _ = gradient(at: Float(1.0), of: bar)

The main function of interest here is foo, which is conditionally differentiable if T : Differentiable.

What we have:

  • Differentiability witness for P.add:
sil_differentiability_witness ... <Self where Self : Differentiable, Self : P> @$s4main1PP3add3rhsxx_tF : $@convention(witness_method: P) <τ_0_0 where τ_0_0 : P> ...
  • protocol witness for P in S
sil_witness_table hidden S: P module main {
  method #P.add: <Self where Self : P> (Self) -> (Self) -> Self : @$s4main1SVAA1PA2aDP3add3rhsxx_tFTW   // protocol witness for P.add(rhs:) in conformance S
}
  • protocol witness for Differentiable in S
sil_witness_table hidden S: Differentiable module main { ... }

Now the problem is that we have a derivative for P.add in foo only if T : Differentiable. So it seems we're missing some kind of "joint" or "unified" witness table here, with methods that appear only when a type conforms to multiple protocols at the same time. foo itself calls P.add, and this call is fulfilled via the protocol witness for P in S. But in the derivative of foo we need to find the derivative of P.add, and where should it come from? The method (derivative) only appears when the type conforms to both protocols at the same time.

If P were Differentiable, then we'd just emit method #P.vjpAdd in the protocol witness table for P in S. But that is not the case here; we have separate witness tables for different protocols.
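
A plain-Swift analogy may help with the "no witness table slot" point. It does not use _Differentiation: the marker protocol `D` stands in for Differentiable, and `describe()` stands in for the derivative, so all names here are hypothetical. A member declared in a constrained protocol extension is not a protocol requirement, so no witness table has an entry for it, and generic code cannot dispatch dynamically to a conforming type's own version.

```swift
protocol P {
    func add(_ rhs: Self) -> Self
}

// Marker protocol, standing in for Differentiable in this analogy.
protocol D {}

// Only available when a type conforms to both P and D.
// It is not a requirement of either protocol.
extension P where Self: D {
    func describe() -> String { "extension default" }
}

struct S: P, D {
    var s: Float
    func add(_ rhs: S) -> S { S(s: s + rhs.s) }
    // Same name as the extension member, but nothing in a witness table refers to it.
    func describe() -> String { "S's own implementation" }
}

func describeGenerically<T: P & D>(_ x: T) -> String {
    // Statically bound to the extension member, whatever T declares.
    x.describe()
}

let value = S(s: 1)
print(value.describe())            // S's own implementation
print(describeGenerically(value))  // extension default
```

The generic function never reaches `S.describe()`, because the only tables available to it are the ones for P and for D, and neither has a slot for a member that exists only when both conformances hold.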

@slavapestov
Contributor

My problem here is how derivatives of protocol methods are represented, as we need to be able to look up derivatives registered retroactively. Consider, e.g., the following code (it won't compile right now because Sema prohibits this, plus a couple of existing PRs):

I couldn’t really follow your example since I’m not familiar with the jargon used here, but from a semantic standpoint a witness_method with a concrete conformance is essentially equivalent to a function_ref for the protocol witness thunk.

@asl
Contributor

asl commented Oct 10, 2025

@slavapestov

witness_method with a concrete conformance is essentially equivalent to a function_ref for the protocol witness thunk.

Well, let me show an example of how it is currently implemented.

Consider the following (supported) case:

import _Differentiation

protocol P : Differentiable {
  @differentiable(reverse, wrt:(self, rhs))
  func add(rhs: Self) -> Self
}

struct S : P {
  var s : Float;
  func add(rhs: Self) -> Self {
    return S(s: s + rhs.s);
  }
}

@differentiable(reverse where T: Differentiable)
func foo<T : P>(x : T, y : T) -> T {
  return x.add(rhs: y);
}

func bar(x : Float) -> Float {
  let s = S(s: x)
  return foo(x: s, y: s).s;
}

let _ = gradient(at: Float(1.0), of: bar)

It currently works as follows:

  • Two additional methods are emitted in the witness table:
sil_witness_table hidden S: P module proto {
  base_protocol Differentiable: S: Differentiable module proto
  method #P.add: <Self where Self : P> (Self) -> (Self) -> Self : @$s5proto1SVAA1PA2aDP3add3rhsxx_tFTW  // protocol witness for P.add(rhs:) in conformance S
  method #P.add!jvp.SS.<Self where Self : P>: <Self where Self : P> (Self) -> (Self) -> Self : @AD__$s5proto1SVAA1PA2aDP3add3rhsxx_tFTW_jvp_SS  // AD__$s5proto1SVAA1PA2aDP3add3rhsxx_tFTW_jvp_SS
  method #P.add!vjp.SS.<Self where Self : P>: <Self where Self : P> (Self) -> (Self) -> Self : @AD__$s5proto1SVAA1PA2aDP3add3rhsxx_tFTW_vjp_SS  // AD__$s5proto1SVAA1PA2aDP3add3rhsxx_tFTW_vjp_SS
}

implementing the forward and reverse derivatives.

  • foo is essentially just a witness_method call:
// foo<A>(x:y:)
sil hidden @$s5proto3foo1x1yxx_xtAA1PRzlF : $@convention(thin) <T where T : P> (@in_guaranteed T, @in_guaranteed T) -> @out T {
...
  %5 = witness_method $T, #P.add : <Self where Self : P> (Self) -> (Self) -> Self : $@convention(witness_method: P) <τ_0_0 where τ_0_0 : P> (@in_guaranteed τ_0_0, @in_guaranteed τ_0_0) -> @out τ_0_0 // user: %6
  %6 = apply %5<T>(%0, %2, %1) : $@convention(witness_method: P) <τ_0_0 where τ_0_0 : P> (@in_guaranteed τ_0_0, @in_guaranteed τ_0_0) -> @out τ_0_0
...
} // end sil function '$s5proto3foo1x1yxx_xtAA1PRzlF'
  • The differentiation transform changes the method from #P.add to, e.g., #P.add!jvp.SS:
// reverse-mode derivative of foo<A>(x:y:)
sil hidden @$s5proto3foo1x1yxx_xtAA1PRzlFAaERzlTJrSSpSr : $@convention(thin) <τ_0_0 where τ_0_0 : P> (@in_guaranteed τ_0_0, @in_guaranteed τ_0_0) -> (@out τ_0_0, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1, τ_0_2> (@in_guaranteed τ_0_0) -> (@out τ_0_1, @out τ_0_2) for <τ_0_0.TangentVector, τ_0_0.TangentVector, τ_0_0.TangentVector>) {
...
  %7 = witness_method $τ_0_0, #P.add!jvp.SS.<Self where Self : P> : <Self where Self : P> (Self) -> (Self) -> Self : $@convention(witness_method: P) <τ_0_0 where τ_0_0 : P> (@in_guaranteed τ_0_0, @in_guaranteed τ_0_0) -> (@out τ_0_0, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1, τ_0_2> (@in_guaranteed τ_0_0, @in_guaranteed τ_0_1) -> @out τ_0_2 for <τ_0_0.TangentVector, τ_0_0.TangentVector, τ_0_0.TangentVector>) // user: %8
...

Now, back to my earlier example: if P itself is not Differentiable, then I'd assume lookUpFunctionInWitnessTable for the witness_method inside foo<T> is not a solution, as the type is not concrete and we simply have no witness table anywhere in which to emit something like the P.add!jvp.SS method shown above.

I guess lookUpFunctionInWitnessTable would only work in cases like the SIMD example above (or code after this PR), where we'd have, e.g.:

  %11 = witness_method $SIMD8<Double>, #AdditiveArithmetic."+" : <Self where Self : AdditiveArithmetic> (Self.Type) -> (Self, Self) -> Self : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@in_guaranteed τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> @out τ_0_0 // user: %12
  %12 = apply %11<SIMD8<Double>>(%4, %7, %9, %6) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@in_guaranteed τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> @out τ_0_0

Am I understanding correctly?

@asl
Contributor

asl commented Oct 10, 2025

@xedin I checked the failing tests. I would prefer not to XFAIL them as a whole, as they contain other useful passing tests. Could you maybe just comment out the failing ones? And I will try to "devirtualize" things during the differentiation transform.

@xedin
Contributor Author

xedin commented Oct 10, 2025

Sure, that’s what I meant by XFAIL really. I will file a GitHub issue to keep track of re-enabling them as well.

@@ -1,5 +1,7 @@
// RUN: %target-swift-frontend -emit-silgen %s | %FileCheck %s

// Operators are no longer devirtualized at AST level, it's done during SIL optimization.
Contributor

If you want, you could add an -emit-sil test to CHECK: that the witness_method is now a function_ref

@asl
Contributor

asl commented Oct 10, 2025

@slavapestov I'm afraid your approach only works for methods defined in the same module that is being compiled (so it won't solve the SIMD case). The reason is as follows: while for a function call a witness_method with a concrete conformance is essentially equivalent to a function_ref for the protocol witness thunk, this is not so for differentiation.

The way autodiff works is that when we see a function call, we try to find or synthesize the derivative. The derivative lookup is done via differentiability witness tables. If there is no differentiability witness for the function, we can trigger differentiation of the function, provided that it is not external.
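
As a rough sketch of that lookup order, here is a toy model. It is a simplification under stated assumptions: `ToyFunction`, `DerivativeSource`, and `findDerivative(of:registered:)` are hypothetical names, a function counts as "external" simply when its body is not in the current module, and registered derivatives are keyed by plain strings.

```swift
// Toy model only: not the real differentiation transform.
struct ToyFunction {
    var name: String
    var hasBodyInCurrentModule: Bool  // false models an external function
}

enum DerivativeSource: Equatable {
    case registeredDerivative(String)  // found via a differentiability witness
    case synthesizedFromBody           // differentiate the function's own body
    case unavailable                   // external and nothing registered
}

/// Looks for a registered derivative first, then falls back to differentiating
/// the body, which in this model is possible only for non-external functions.
func findDerivative(of function: ToyFunction, registered: [String: String]) -> DerivativeSource {
    if let derivative = registered[function.name] {
        return .registeredDerivative(derivative)
    }
    return function.hasBodyInCurrentModule ? .synthesizedFromBody : .unavailable
}

// A derivative is registered for the original function only.
let registered = ["original function": "custom derivative"]

let original = ToyFunction(name: "original function", hasBodyInCurrentModule: false)
let localThunk = ToyFunction(name: "local witness thunk", hasBodyInCurrentModule: true)
let externalThunk = ToyFunction(name: "external witness thunk", hasBodyInCurrentModule: false)

print(findDerivative(of: original, registered: registered))
print(findDerivative(of: localThunk, registered: registered))
print(findDerivative(of: externalThunk, registered: registered))
```

The lookup is keyed by the function being called, so a thunk that merely forwards to a function with a registered derivative does not inherit that registration in this model.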

The important difference here is that the derivative is registered for the original operator SIMD<8>.operator+ and not for the protocol witness thunk! So, from the autodiff perspective, the protocol witness thunk does not have a derivative available, and the only way to obtain one is to trigger differentiation of its body. For local protocol witness thunks this is an option, yes, as we have their bodies available and can emit derivatives for them.

While we can import the body of the protocol witness thunk for AdditiveArithmetic.operator+ in SIMD<8>, it is too late to trigger differentiation of it. The original SIMD<8>.operator+ is already inlined there, things are optimized, and all references to the provided derivative are already lost.

Maybe one way to circumvent this is to trigger differentiation of protocol witness thunks as soon as a derivative is registered for the method, so we can look up this derivative as well.

@xedin That said, I was able to apply @slavapestov's suggestion, and it resolved ~50% of the autodiff test failures. The remaining ones are the SIMD and Tracked<T> ones. Let me clean up the code a bit, so it can be included here as well.

@asl
Contributor

asl commented Oct 10, 2025

@xedin asl@aef0cc9 implements @slavapestov's suggestion. It does not always help, but it covers local cases.

@xedin
Contributor Author

xedin commented Oct 10, 2025

@asl Sounds good, please push it to my branch. Are you going to handle xfail'ing the not-yet-fixed ones too or should I?

@xedin xedin force-pushed the remove-csapply-operator-devirt branch from ea9a8b3 to 9d00c7b Compare October 10, 2025 23:52
@asl
Contributor

asl commented Oct 11, 2025

@asl Sounds good, please push it to my branch. Are you going to handle xfail'ing the not-yet-fixed ones too or should I?

I'm going to play a bit with the idea of automatically triggering differentiation for protocol witness thunks if there is a custom derivative registered for the method. This might fix the remaining issues (hopefully)...

@asl
Contributor

asl commented Oct 11, 2025

@asl Sounds good, please push it to my branch. Are you going to handle xfail'ing the not-yet-fixed ones too or should I?

I am going to rewrite some tests not to use Tracked<>. We will lose leak checking, but at least keep the coverage. So we'd only XFAIL the SIMD ones for now.

@xedin
Contributor Author

xedin commented Oct 11, 2025

Sounds good!

@asl asl requested review from asl and eeckstein as code owners October 13, 2025 02:49
@asl
Contributor

asl commented Oct 13, 2025

@swift-ci please test

@asl
Contributor

asl commented Oct 13, 2025

@xedin I updated tests not to use Tracked<T> where appropriate and pushed to your branch

@xedin
Contributor Author

xedin commented Oct 13, 2025

Looks like there is a crash in test/AutoDiff/validation-test/method.swift now and everything else is happy.

@asl
Contributor

asl commented Oct 13, 2025

@xedin Yeah. But it looks like it is Linux-specific and happens in autodiff closure specialization:

[2025-10-13T04:12:02.675Z] swift runtime: unknown backtracing setting 'warnings'
[2025-10-13T04:12:02.675Z] Assertion failed: (result && "instruction not cloned"), function cloneInst at SILBridging.cpp:606.
[2025-10-13T04:12:02.675Z] (to display assertion configuration options: -Xllvm -assert-help)
[2025-10-13T04:12:02.675Z] 
[2025-10-13T04:12:02.675Z] Please submit a bug report (https://swift.org/contributing/#reporting-bugs) and include the crash backtrace.
[2025-10-13T04:12:02.675Z] Stack dump:
[2025-10-13T04:12:02.675Z] 0.	Program arguments: /home/build-user/build/buildbot_linux/swift-linux-x86_64/bin/swift-frontend -frontend -c -primary-file /home/build-user/swift/test/AutoDiff/validation-
...
[2025-10-13T04:12:02.675Z] 4.	While running pass #29366 SILFunctionTransform "AutodiffClosureSpecialization" on SILFunction "@$s4main15CustomParameterV20multiplied_constSelf4withS2f_tFTJrSUpSr".
[2025-10-13T04:12:02.675Z]  for 'dMultiplied_wrtOther(with:)' (at /home/build-user/swift/test/AutoDiff/validation-test/method.swift:418:3)
[2025-10-13T04:12:02.675Z] 5.	Assertion failed: (result && "instruction not cloned"), function cloneInst at SILBridging.cpp:606.
[2025-10-13T04:12:02.675Z] | 	(to display assertion configuration options: -Xllvm -assert-help)
[2025-10-13T04:12:02.675Z]  #0 0x000055fb9f78e7b8 llvm::sys::PrintStackTrace(llvm::raw_ostream&, int) (/home/build-user/build/buildbot_linux/swift-linux-x86_64/bin/swift-frontend+0x89337b8)
[2025-10-13T04:12:02.675Z]  #1 0x000055fb9f78bfa5 llvm::sys::RunSignalHandlers() (/home/build-user/build/buildbot_linux/swift-linux-x86_64/bin/swift-frontend+0x8930fa5)
[2025-10-13T04:12:02.675Z]  #2 0x000055fb9f78f561 SignalHandler(int, siginfo_t*, void*) Signals.cpp:0:0
[2025-10-13T04:12:02.675Z]  #3 0x00007f8ebad0f520 (/lib/x86_64-linux-gnu/libc.so.6+0x42520)
[2025-10-13T04:12:02.675Z]  #4 0x00007f8ebad639fc pthread_kill (/lib/x86_64-linux-gnu/libc.so.6+0x969fc)
[2025-10-13T04:12:02.675Z]  #5 0x00007f8ebad0f476 gsignal (/lib/x86_64-linux-gnu/libc.so.6+0x42476)
[2025-10-13T04:12:02.675Z]  #6 0x00007f8ebacf57f3 abort (/lib/x86_64-linux-gnu/libc.so.6+0x287f3)
[2025-10-13T04:12:02.675Z]  #7 0x000055fb999e92f2 (/home/build-user/build/buildbot_linux/swift-linux-x86_64/bin/swift-frontend+0x2b8e2f2)
[2025-10-13T04:12:02.675Z]  #8 0x000055fb999e92a4 (/home/build-user/build/buildbot_linux/swift-linux-x86_64/bin/swift-frontend+0x2b8e2a4)
[2025-10-13T04:12:02.675Z]  #9 0x000055fb985ca5f2 BridgedCloner::clone(BridgedInstruction) const (/home/build-user/build/buildbot_linux/swift-linux-x86_64/bin/swift-frontend+0x176f5f2)
[2025-10-13T04:12:02.675Z] #10 0x000055fb977092d0 $s3SIL6ClonerV16cloneRecursively5value15customGetClonedAA5Value_pSgAaG_p_AC0gH6ResultOyx_GAaG_p_ACyxGztXEtF9Optimizer19FunctionPassContextV_Tg5Tf4enn_n crtstuff.c:0:0
[2025-10-13T04:12:02.675Z] #11 0x000055fb975ff2d2 $sSlsE3mapySayqd__Gqd__7ElementQzKXEKlFSay3SIL7OperandV15closureArgument_AE22SingleValueInstructionC11rootClosuretG_AE0H0_pTg50100$s9Optimizer18SpecializationInfo028_D216D6F3C0B5FE4DA71AFDABCA3E3C2ALLV13cloneClosures5usingSay3SIL5h51_pGAG6ClonerVyAA19FunctionPassContextVGz_tFAgH_pAG7d6V_AG06g3Q11I7CtXEfU_AE6ClonerVy9Optimizer19FunctionPassContextVGAQ0N4Info01_pqrstuvwqY1ALLVTf1cn_n crtstuff.c:0:0
[2025-10-13T04:12:02.675Z] #12 0x000055fb97619cf9 

Tagging @kovdan01 to take a look. This is likely due to the recent rewrite of the autodiff closure specialization pass by @eeckstein.
