From 3011cc6077fc3de1fc9781c741fe793351f08d44 Mon Sep 17 00:00:00 2001 From: Anton Korobeynikov Date: Wed, 17 Sep 2025 15:46:33 -1000 Subject: [PATCH] Correctly handle multiple semantic results for autodiff subset differential thunks. Fixes https://github.com/swiftlang/swift/issues/84365 --- lib/SILOptimizer/Differentiation/Thunk.cpp | 29 +++++++---- ...-autodiff-subset-thunks-differential.swift | 49 +++++++++++++++++++ 2 files changed, 69 insertions(+), 9 deletions(-) create mode 100644 test/AutoDiff/compiler_crashers_fixed/issue-84365-autodiff-subset-thunks-differential.swift diff --git a/lib/SILOptimizer/Differentiation/Thunk.cpp b/lib/SILOptimizer/Differentiation/Thunk.cpp index e4328547d90d3..8081616571d9b 100644 --- a/lib/SILOptimizer/Differentiation/Thunk.cpp +++ b/lib/SILOptimizer/Differentiation/Thunk.cpp @@ -621,22 +621,33 @@ getOrCreateSubsetParametersThunkForLinearMap( // If differential thunk, deallocate local allocations and directly return // `apply` result (if it is desired). + // TODO: Unify with VJP code below if (kind == AutoDiffDerivativeFunctionKind::JVP) { SmallVector differentialDirectResults; extractAllElements(ai, builder, differentialDirectResults); SmallVector allResults; collectAllActualResultsInTypeOrder(ai, differentialDirectResults, allResults); - unsigned numResults = thunk->getConventions().getNumDirectSILResults() + - thunk->getConventions().getNumDirectSILResults(); SmallVector results; - for (unsigned idx : *actualConfig.resultIndices) { - if (idx >= numResults) - break; - auto result = allResults[idx]; - if (desiredConfig.isWrtResult(idx)) - results.push_back(result); - else { + unsigned firstSemanticParamResultIdx = origFnType->getNumResults(); + for (unsigned resultIndex : *actualConfig.resultIndices) { + SILValue result; + if (resultIndex >= firstSemanticParamResultIdx) { + auto semanticResultArgIdx = resultIndex - firstSemanticParamResultIdx; + result = + *std::next(ai->getAutoDiffSemanticResultArguments().begin(), + semanticResultArgIdx); + } else + result = allResults[resultIndex]; + + // If result is desired: + // - Do nothing if result is indirect. + // (It was already forwarded to the `apply` instruction). + // - Push it to `results` if result is direct. + if (desiredConfig.isWrtResult(resultIndex)) { + if (result->getType().isObject()) + results.push_back(result); + } else { // Otherwise, cleanup the unused results. if (result->getType().isAddress()) builder.emitDestroyAddrAndFold(loc, result); else diff --git a/test/AutoDiff/compiler_crashers_fixed/issue-84365-autodiff-subset-thunks-differential.swift b/test/AutoDiff/compiler_crashers_fixed/issue-84365-autodiff-subset-thunks-differential.swift new file mode 100644 index 0000000000000..ac88998f13625 --- /dev/null +++ b/test/AutoDiff/compiler_crashers_fixed/issue-84365-autodiff-subset-thunks-differential.swift @@ -0,0 +1,49 @@ +// RUN: %target-swift-frontend -emit-sil -verify %s + +// https://github.com/swiftlang/swift/issues/84365 +// Ensure autodiff subset thunks for differential correctly +// handle multiple semantic results and release unwanted +// result values + +import _Differentiation + +@differentiable(reverse,wrt: logits) +public func softSolveForwardWithQ(logits: [Float]) -> ([Float], [Float]) { + return ([Float](repeating: 0, count: 0), []) +} + +@derivative(of: softSolveForwardWithQ, wrt: logits) +public func vjpSoftSolveForwardWithQ(logits: [Float]) -> (value: ([Float], [Float]), pullback: ([Float].TangentVector, [Float].TangentVector) -> [Float].TangentVector) { + let n = logits.count + let q = [Float](repeating: 0, count: 0) + let y = [Float](repeating: 0, count: 0) + + return ( + value: (y, q), + pullback: { _, _ in + return Array.DifferentiableView([Float](repeating: 0, count: n)) + } + ) +} + +@differentiable(reverse,wrt: logits) +public func forwardPredict(logits: [Float]) -> ([Float], [Float], [Float]) { + let (y, q) = softSolveForwardWithQ(logits: logits) + return (y, q, [0.0]) +} + +@derivative(of: forwardPredict, wrt: logits) +public func vjpForwardPredict(logits: [Float]) -> ( + value: ([Float], [Float], [Float]), + pullback: ([Float].TangentVector, [Float].TangentVector, [Float].TangentVector) -> [Float].TangentVector +) { + let (valYQ, pb) = vjpSoftSolveForwardWithQ(logits: logits) + let (y, q) = valYQ + return ((y, q, [0.0]), { upY, upQ, _ in pb(upY, upQ) }) +} + +@differentiable(reverse,wrt: logits) +public func crossEntropyFromForwardPredict(logits: [Float]) -> Float { + let (_, q, _) = forwardPredict(logits: logits) + return q[0] + 1e-8 +}