From 00b186830cf7787b75c16671cd37ab3a49b65fd6 Mon Sep 17 00:00:00 2001 From: Daniil Kovalev Date: Fri, 10 Oct 2025 00:06:52 +0300 Subject: [PATCH] [AutoDiff] Closure specialization: specialize branch tracing enums This patch contains part of the changes intended to resolve #68944. 1. Closure info gathering logic. 2. Branch tracing enum specialization logic. 3. Specialization of branch tracing enum basic block arguments in VJP. 4. Specialization of branch tracing enum payload basic block arguments in pullback. --- .../ClosureSpecialization.swift | 736 ++++++++++++++++++ .../Optimizer/Utilities/FunctionTest.swift | 4 + include/swift/SIL/SILBridging.h | 1 + lib/SILOptimizer/Utils/OptimizerBridging.cpp | 68 ++ .../closure_specialization/multi_bb_bte.sil | 249 ++++++ 5 files changed, 1058 insertions(+) diff --git a/SwiftCompilerSources/Sources/Optimizer/FunctionPasses/ClosureSpecialization.swift b/SwiftCompilerSources/Sources/Optimizer/FunctionPasses/ClosureSpecialization.swift index b8dbafcb12afa..2ccabda0bf3ce 100644 --- a/SwiftCompilerSources/Sources/Optimizer/FunctionPasses/ClosureSpecialization.swift +++ b/SwiftCompilerSources/Sources/Optimizer/FunctionPasses/ClosureSpecialization.swift @@ -13,6 +13,14 @@ import AST import SIL +private let verbose = false + +private func log(prefix: Bool = true, _ message: @autoclosure () -> String) { + if verbose { + debugLog(prefix: prefix, "[ADCS] " + message()) + } +} + /// Closure Specialization /// ---------------------- /// Specializes functions which take a closure (a `partial_apply` or `thin_to_thick_function` as argument. @@ -791,3 +799,731 @@ private extension Function { } } } + +extension Collection { + func getExactlyOneOrNil() -> Element? { + assert(self.count <= 1) + return self.first + } +} + +extension Type: Hashable { + func isBranchTracingEnumIn(vjp: Function) -> Bool { + return self.bridged.isAutodiffBranchTracingEnumInVJP(vjp.bridged) + } + + public func hash(into hasher: inout Hasher) { + hasher.combine(bridged.opaqueValue) + } +} + +extension BasicBlock { + fileprivate func getBranchTracingEnumArg(vjp: Function) -> Argument? { + return self.arguments.filter { $0.type.isBranchTracingEnumIn(vjp: vjp) }.getExactlyOneOrNil() + } +} + +typealias ClosureInfoMultiBB = ( + closure: SingleValueInstruction, + capturedArgs: [Value], + subsetThunk: PartialApplyInst?, + payloadTuple: TupleInst, + idxInPayload: Int, + enumTypeAndCase: EnumTypeAndCase +) + +typealias EnumTypeAndCase = ( + enumType: Type, + caseIdx: Int +) + +typealias BTEToPredsDict = [SIL.`Type`: [SIL.`Type`]] + +typealias BTECaseToClosureListDict = [Int: [ClosureInfoMultiBB]] + +typealias SpecBTEDict = [SIL.`Type`: SIL.`Type`] + +func getCapturedArgTypesTupleForClosure( + closure: SingleValueInstruction, context: FunctionPassContext +) -> AST.`Type` { + var capturedArgTypes = [AST.`Type`]() + if let pai = closure as? PartialApplyInst { + for arg in pai.arguments { + capturedArgTypes.append(arg.type.rawType) + } + } else if closure as? ThinToThickFunctionInst != nil { + // Nothing captured + } else { + assert(false) + } + return context.getTupleType(elements: capturedArgTypes) +} + +func getBranchTracingEnumPreds(bteType: SIL.`Type`, vjp: Function) -> [SIL.`Type`] { + var btePreds = [SIL.`Type`]() + guard let enumCases = bteType.getEnumCases(in: vjp) else { + return btePreds + } + for enumCase in enumCases { + let payloadType: SIL.`Type` = enumCase.payload! + if payloadType.tupleElements.count == 0 { + continue + } + let firstTupleElementType: SIL.`Type` = payloadType.tupleElements[0] + if firstTupleElementType.isBranchTracingEnumIn(vjp: vjp) { + btePreds.append(firstTupleElementType) + } + } + return btePreds +} + +func iterateOverBranchTracingEnumPreds( + bteToPredsDict: inout BTEToPredsDict, + currentBTEType: SIL.`Type`, + vjp: Function +) { + let currentBTEPreds: [SIL.`Type`] = getBranchTracingEnumPreds(bteType: currentBTEType, vjp: vjp) + bteToPredsDict[currentBTEType] = currentBTEPreds + for currentBTEPred in currentBTEPreds { + if !bteToPredsDict.keys.contains(currentBTEPred) { + iterateOverBranchTracingEnumPreds( + bteToPredsDict: &bteToPredsDict, currentBTEType: currentBTEPred, vjp: vjp) + } + } +} + +func getBranchTracingEnumQueue(topBTEType: SIL.`Type`, vjp: Function) -> [SIL.`Type`] { + var bteToPredsDict = BTEToPredsDict() + iterateOverBranchTracingEnumPreds( + bteToPredsDict: &bteToPredsDict, + currentBTEType: topBTEType, + vjp: vjp) + var bteQueue = [SIL.`Type`]() + let totalEnums = bteToPredsDict.count + + for i in 0.. SIL.`Type` { + var silType = ty + if silType.rawType.hasArchetype { + silType = silType.mapTypeOutOfContext(in: function) + } + let remappedCanType = silType.rawType.getReducedType( + of: function.loweredFunctionType.substitutedGenericSignatureOfFunctionType.genericSignature) + let remappedSILType = remappedCanType.loweredType(in: function) + if !function.genericSignature.isEmpty { + return function.mapTypeIntoContext(remappedSILType) + } + return remappedSILType +} + +func getBranchingTraceEnumLoweredType(ed: Decl, vjp: Function) -> SIL.`Type` { + ed.bridged.getAs(NominalTypeDecl.self).declaredInterfaceType.canonical.loweredType(in: vjp) +} + +func getSourceFileFor(derivative: Function) -> SourceFile { + if let sourceFile = derivative.sourceFile { + return sourceFile + } + return derivative.bridged.getFilesForModule().withElements(ofType: FileUnit.self) { + for fileUnit in $0 { + if let sourceFile = fileUnit.asSourceFile { + return sourceFile + } + } + assert(false) + return nil + }! +} + +func cloneGenericParameters( + astContext: ASTContext, declContext: DeclContext, canonicalGenericSig: CanonicalGenericSignature +) -> GenericParamList { + var params: [BridgedGenericTypeParamDecl] = [] + for type in canonicalGenericSig.genericSignature.genericParameters { + assert(type.isGenericTypeParameter) + params.append( + BridgedGenericTypeParamDecl.createImplicit( + declContext: declContext, + name: type.nameOfGenericTypeParameter, + depth: type.depthOfGenericTypeParameter, + index: type.indexOfGenericTypeParameter, + paramKind: type.kindOfGenericTypeParameter)) + } + return GenericParamList.createParsed( + astContext, leftAngleLoc: nil, parameters: params, + genericWhereClause: nil, + rightAngleLoc: nil) +} + +func autodiffSpecializeBranchTracingEnum( + bteType: SIL.`Type`, topVJP: Function, + bteCaseToClosureListDict: BTECaseToClosureListDict, + specBTEDict: [SIL.`Type`: SIL.`Type`], + context: FunctionPassContext +) -> SIL.`Type` { + assert(specBTEDict[bteType] == nil) + + let oldED = bteType.nominal as! EnumDecl + let declContext = oldED.declContext + let astContext = declContext.astContext + + var newEDNameStr: String = oldED.name.string + "_spec" + var newPLs = [ParameterList]() + + for enumCase in bteType.getEnumCases(in: topVJP)! { + let oldPayloadTupleType: Type = enumCase.payload! + let oldEED: EnumElementDecl = enumCase.enumElementDecl + + let oldPL: ParameterList = oldEED.parameterList + assert(oldPL.size == 1) + let oldPD: BridgedParamDecl = oldPL[0] + + let closureInfosMultiBB: [ClosureInfoMultiBB] = bteCaseToClosureListDict[enumCase.index] ?? [] + + var newECDNameSuffix: String = "" + var newPayloadTupleElementTypes = [(label: Identifier, type: AST.`Type`)]() + + for idxInPayloadTuple in 0.. SpecBTEDict { + let bteQueue: [SIL.`Type`] = getBranchTracingEnumQueue(topBTEType: topBTE, vjp: topVJP) + + var specBTEDict = [SIL.`Type`: SIL.`Type`]() + for bteType in bteQueue { + let ed = bteType.nominal as! EnumDecl + let silType = remapType( + ty: getBranchingTraceEnumLoweredType(ed: ed, vjp: topVJP), function: topVJP) + + var bteCaseToClosureListDict = BTECaseToClosureListDict() + for closureInfoMultiBB in closureInfosMultiBB { + if closureInfoMultiBB.enumTypeAndCase.enumType != bteType { + continue + } + if bteCaseToClosureListDict[closureInfoMultiBB.enumTypeAndCase.caseIdx] == nil { + bteCaseToClosureListDict[closureInfoMultiBB.enumTypeAndCase.caseIdx] = [] + } + bteCaseToClosureListDict[closureInfoMultiBB.enumTypeAndCase.caseIdx]!.append( + closureInfoMultiBB) + } + + specBTEDict[silType] = autodiffSpecializeBranchTracingEnum( + bteType: silType, topVJP: topVJP, bteCaseToClosureListDict: bteCaseToClosureListDict, + specBTEDict: specBTEDict, context: context) + } + + return specBTEDict +} + +private func getPartialApplyOfPullbackInExitVJPBB(vjp: Function) -> PartialApplyInst? { + log("getPartialApplyOfPullbackInExitVJPBB: running for VJP \(vjp.name)") + guard let exitBB = vjp.blocks.filter({ $0.terminator as? ReturnInst != nil }).getExactlyOneOrNil() + else { + log("getPartialApplyOfPullbackInExitVJPBB: exit BB not found, aborting") + return nil + } + + let ri = exitBB.terminator as! ReturnInst + guard let retValDefiningInstr = ri.returnedValue.definingInstruction else { + log( + "getPartialApplyOfPullbackInExitVJPBB: return value is not defined by an instruction, aborting" + ) + return nil + } + + func handleConvertFunctionOrPartialApply(inst: Instruction) -> PartialApplyInst? { + if let pai = inst as? PartialApplyInst { + log("getPartialApplyOfPullbackInExitVJPBB: success") + return pai + } + if let cfi = inst as? ConvertFunctionInst { + if let pai = cfi.fromFunction as? PartialApplyInst { + log("getPartialApplyOfPullbackInExitVJPBB: success") + return pai + } + log( + "getPartialApplyOfPullbackInExitVJPBB: fromFunction operand of convert_function instruction is not defined by partial_apply instruction, aborting" + ) + return nil + } + log("getPartialApplyOfPullbackInExitVJPBB: unexpected instruction type, aborting") + return nil + } + + if let ti = retValDefiningInstr as? TupleInst { + log("getPartialApplyOfPullbackInExitVJPBB: return value is defined by tuple instruction") + if ti.operands.count < 2 { + log( + "getPartialApplyOfPullbackInExitVJPBB: tuple instruction has \(ti.operands.count) operands, but at least 2 expected, aborting" + ) + return nil + } + guard let lastTupleElemDefiningInst = ti.operands.last!.value.definingInstruction else { + log( + "getPartialApplyOfPullbackInExitVJPBB: last tuple element is not defined by an instruction, aborting" + ) + return nil + } + return handleConvertFunctionOrPartialApply(inst: lastTupleElemDefiningInst) + } + + return handleConvertFunctionOrPartialApply(inst: retValDefiningInstr) +} + +private func getPullbackClosureInfoMultiBB(in vjp: Function, _ context: FunctionPassContext) + -> PullbackClosureInfo +{ + let paiOfPbInExitVjpBB = getPartialApplyOfPullbackInExitVJPBB(vjp: vjp)! + var pullbackClosureInfo = PullbackClosureInfo(paiOfPullback: paiOfPbInExitVjpBB) + var subsetThunkArr = [SingleValueInstruction]() + + for inst in vjp.instructions { + if inst == paiOfPbInExitVjpBB { + continue + } + if inst.asSupportedClosure == nil { + continue + } + + let rootClosure = inst.asSupportedClosure! + if subsetThunkArr.contains(rootClosure) { + continue + } + + let closureInfoArr = handleNonAppliesMultiBB(for: rootClosure, context) + pullbackClosureInfo.closureInfosMultiBB.append(contentsOf: closureInfoArr) + subsetThunkArr.append( + contentsOf: closureInfoArr.filter { $0.subsetThunk != nil }.map { $0.subsetThunk! }) + } + + return pullbackClosureInfo +} + +typealias BTEPayloadArgOfPbBBWithBTETypeAndCase = (arg: Argument, enumTypeAndCase: EnumTypeAndCase) + +// If the pullback's basic block has an argument which is a payload tuple of the +// branch tracing enum corresponding to the given VJP, return this argument and any valid combination +// of a branch tracing enum type and its case index having the same payload tuple type as the argument. +// The function assumes that no more than one such argument is present. +private func getBTEPayloadArgOfPbBBWithBTETypeAndCase(_ bb: BasicBlock, vjp: Function) + -> BTEPayloadArgOfPbBBWithBTETypeAndCase? +{ + log( + "getBTEPayloadArgOfPbBBWithBTETypeAndCase: basic block \(bb.shortDescription) in pullback \(bb.parentFunction.name)" + ) + guard let predBB = bb.predecessors.first else { + log("getBTEPayloadArgOfPbBBWithBTETypeAndCase: the bb has no predecessors, aborting") + return nil + } + + log("getBTEPayloadArgOfPbBBWithBTETypeAndCase: start iterating over bb args") + for arg in bb.arguments { + log("getBTEPayloadArgOfPbBBWithBTETypeAndCase: \(arg)") + if !arg.type.isTuple { + log("getBTEPayloadArgOfPbBBWithBTETypeAndCase: arg is not a tuple, skipping") + continue + } + + if let bi = predBB.terminator as? BranchInst { + log("getBTEPayloadArgOfPbBBWithBTETypeAndCase: terminator of pred bb is branch instruction") + guard let uedi = bi.operands[arg.index].value.definingInstruction as? UncheckedEnumDataInst + else { + log( + "getBTEPayloadArgOfPbBBWithBTETypeAndCase: operand corresponding to the argument is not defined by unchecked_enum_data instruction" + ) + continue + } + let enumType = uedi.`enum`.type + if !enumType.isBranchTracingEnumIn(vjp: vjp) { + log( + "getBTEPayloadArgOfPbBBWithBTETypeAndCase: enum type \(enumType) is not a branch tracing enum in VJP \(vjp.name)" + ) + continue + } + + log("getBTEPayloadArgOfPbBBWithBTETypeAndCase: success") + return BTEPayloadArgOfPbBBWithBTETypeAndCase( + arg: arg, + enumTypeAndCase: ( + enumType: enumType, + caseIdx: uedi.caseIndex + ) + ) + } + + if let sei = predBB.terminator as? SwitchEnumInst { + log( + "getBTEPayloadArgOfPbBBWithBTETypeAndCase: terminator of pred bb is switch_enum instruction" + ) + let enumType = sei.enumOp.type + if !enumType.isBranchTracingEnumIn(vjp: vjp) { + log( + "getBTEPayloadArgOfPbBBWithBTETypeAndCase: enum type \(enumType) is not a branch tracing enum in VJP \(vjp.name)" + ) + continue + } + + log("getBTEPayloadArgOfPbBBWithBTETypeAndCase: success") + return BTEPayloadArgOfPbBBWithBTETypeAndCase( + arg: arg, + enumTypeAndCase: ( + enumType: enumType, + caseIdx: sei.getUniqueCase(forSuccessor: bb)! + ) + ) + } + } + + log( + "getBTEPayloadArgOfPbBBWithBTETypeAndCase: finish iterating over bb args; branch tracing enum arg not found" + ) + return nil +} + +extension PartialApplyInst { + func isSubsetThunk() -> Bool { + if self.argumentOperands.singleElement == nil { + return false + } + guard let function = self.referencedFunction else { + return false + } + return function.bridged.isAutodiffSubsetParametersThunk() + } +} + +private func handleNonAppliesMultiBB( + for rootClosure: SingleValueInstruction, + _ context: FunctionPassContext +) + -> [ClosureInfoMultiBB] +{ + log("handleNonAppliesMultiBB: running for \(rootClosure)") + let vjp = rootClosure.parentFunction + var closureInfoArr = [ClosureInfoMultiBB]() + + var closure = rootClosure + var subsetThunk = PartialApplyInst?(nil) + if rootClosure.uses.singleElement != nil { + if let pai = closure.uses.singleElement!.instruction as? PartialApplyInst { + if pai.isSubsetThunk() { + log("handleNonAppliesMultiBB: found subset thunk \(pai)") + subsetThunk = pai + closure = pai + } + } + } + + for use in closure.uses { + guard let ti = use.instruction as? TupleInst else { + log("handleNonAppliesMultiBB: unexpected use of closure, aborting: \(use)") + return [] + } + for tiUse in ti.uses { + guard let ei = tiUse.instruction as? EnumInst else { + log("handleNonAppliesMultiBB: unexpected use of payload tuple, aborting: \(tiUse)") + return [] + } + if !ei.type.isBranchTracingEnumIn(vjp: vjp) { + log( + "handleNonAppliesMultiBB: enum type \(ei.type) is not a branch tracing enum in VJP \(vjp.name), aborting" + ) + return [] + } + var capturedArgs = [Value]() + if let pai = rootClosure as? PartialApplyInst { + capturedArgs = pai.argumentOperands.map { $0.value } + } + log( + "handleNonAppliesMultiBB: creating closure info with enum type \(ei.type), case index \(ei.caseIndex), index in payload tuple \(use.index) and payload tuple \(ti)" + ) + let enumTypeAndCase = (enumType: ei.type, caseIdx: ei.caseIndex) + closureInfoArr.append( + ClosureInfoMultiBB( + closure: rootClosure, + capturedArgs: capturedArgs, + subsetThunk: subsetThunk, + payloadTuple: ti, + idxInPayload: use.index, + enumTypeAndCase: enumTypeAndCase + )) + } + } + log( + "handleNonAppliesMultiBB: created \(closureInfoArr.count) closure info entries for \(rootClosure)" + ) + return closureInfoArr +} + +extension Instruction { + fileprivate var asSupportedClosure: SingleValueInstruction? { + switch self { + case let tttf as ThinToThickFunctionInst where tttf.callee is FunctionRefInst: + return tttf + // TODO: figure out what to do with non-inout indirect arguments + // https://forums.swift.org/t/non-inout-indirect-types-not-supported-in-closure-specialization-optimization/70826 + case let pai as PartialApplyInst + where pai.callee is FunctionRefInst && pai.hasOnlyInoutIndirectArguments: + return pai + default: + return nil + } + } + fileprivate var isSupportedClosure: Bool { + asSupportedClosure != nil + } +} + +/// Represents a partial_apply of pullback capturing one or more closure arguments. +private struct PullbackClosureInfo { + let paiOfPullback: PartialApplyInst + var closureInfosMultiBB: [ClosureInfoMultiBB] = [] + + init(paiOfPullback: PartialApplyInst) { + self.paiOfPullback = paiOfPullback + } + var pullbackFn: Function { + paiOfPullback.referencedFunction! + } +} + +let getPullbackClosureInfoMultiBBTest = FunctionTest( + "autodiff_closure_specialize_get_pullback_closure_info_multi_bb" +) { + function, arguments, context in + let pullbackClosureInfo = getPullbackClosureInfoMultiBB(in: function, context) + print("Run getPullbackClosureInfoMultiBB for VJP \(function.name): pullbackClosureInfo = (") + print(" pullbackFn = \(pullbackClosureInfo.pullbackFn.name)") + print(" closureInfosMultiBB = [") + for closureInfoMultiBB in pullbackClosureInfo.closureInfosMultiBB { + print(" ClosureInfoMultiBB(") + print(" closure: \(closureInfoMultiBB.closure)") + print(" capturedArgs: [") + for capturedArg in closureInfoMultiBB.capturedArgs { + print(" \(capturedArg)") + } + print(" ]") + let subsetThunkStr = + (closureInfoMultiBB.subsetThunk == nil ? "nil" : "\(closureInfoMultiBB.subsetThunk!)") + print(" subsetThunk: \(subsetThunkStr)") + print(" payloadTuple: \(closureInfoMultiBB.payloadTuple)") + print(" idxInPayload: \(closureInfoMultiBB.idxInPayload)") + print(" enumTypeAndCase: \(closureInfoMultiBB.enumTypeAndCase)") + print(" )") + } + print(" ]\n)\n") +} + +func getSpecBTEDict(vjp: Function, context: FunctionPassContext) -> SpecBTEDict { + let pullbackClosureInfo = getPullbackClosureInfoMultiBB(in: vjp, context) + let pb = pullbackClosureInfo.pullbackFn + let enumTypeOfEntryBBArg = pb.entryBlock.getBranchTracingEnumArg(vjp: vjp)!.type + let enumDict = autodiffSpecializeBranchTracingEnums( + topVJP: vjp, topBTE: enumTypeOfEntryBBArg, + closureInfosMultiBB: pullbackClosureInfo.closureInfosMultiBB, context: context) + return enumDict +} + +func specializeBranchTracingEnumBBArgInVJP( + arg: Argument, specBTEDict: SpecBTEDict, context: FunctionPassContext +) -> Argument { + let bb = arg.parentBlock + assert(specBTEDict[arg.type] != nil) + let newType = specBTEDict[arg.type]! + return bb.insertPhiArgument( + atPosition: arg.index, type: newType, ownership: arg.ownership, context) +} + +extension SIL.`Type`: Comparable { + public static func < (lhs: SIL.`Type`, rhs: SIL.`Type`) -> Bool { + return "\(lhs)" < "\(rhs)" + } +} + +let specializeBranchTracingEnums = FunctionTest("autodiff_specialize_branch_tracing_enums") { + function, arguments, context in + let enumDict = getSpecBTEDict(vjp: function, context: context) + print( + "Specialized branch tracing enum dict for VJP \(function.name) contains \(enumDict.count) elements:" + ) + + var keys = [SIL.`Type`](enumDict.keys) + keys.sort() + for (idx, key) in keys.enumerated() { + print("non-specialized BTE \(idx): \(key.nominal!.description)") + print("specialized BTE \(idx): \(enumDict[key]!.nominal!.description)") + } + print("") +} + +let specializeBTEArgInVjpBB = FunctionTest("autodiff_specialize_bte_arg_in_vjp_bb") { + function, arguments, context in + let enumDict = getSpecBTEDict(vjp: function, context: context) + print("Specialized BTE arguments of basic blocks in VJP \(function.name):") + for bb in function.blocks { + guard let arg = bb.getBranchTracingEnumArg(vjp: function) else { + continue + } + let newArg = specializeBranchTracingEnumBBArgInVJP( + arg: arg, specBTEDict: enumDict, context: context) + print("\(newArg)") + bb.eraseArgument(at: newArg.index, context) + } + print("") +} + +func specializePayloadTupleBBArgInPullback( + arg: Argument, + enumTypeAndCase: EnumTypeAndCase, + context: FunctionPassContext +) -> Argument { + let bb = arg.parentBlock + let newEnumType = enumTypeAndCase.enumType + + var newPayloadTupleTy = SIL.`Type`?(nil) + for enumCase in newEnumType.getEnumCases(in: arg.parentFunction)! { + if enumCase.index == enumTypeAndCase.caseIdx { + newPayloadTupleTy = enumCase.payload! + break + } + } + assert(newPayloadTupleTy != nil) + + return bb.insertPhiArgument( + atPosition: arg.index, type: newPayloadTupleTy!, ownership: arg.ownership, context) +} + +let specializePayloadArgInPullbackBB = FunctionTest("autodiff_specialize_payload_arg_in_pb_bb") { + function, arguments, context in + let pullbackClosureInfo = getPullbackClosureInfoMultiBB(in: function, context) + let pb = pullbackClosureInfo.pullbackFn + let enumDict = getSpecBTEDict(vjp: function, context: context) + + print("Specialized BTE payload arguments of basic blocks in pullback \(pb.name):") + for bb in pb.blocks { + guard + let (arg, enumTypeAndCase) = getBTEPayloadArgOfPbBBWithBTETypeAndCase(bb, vjp: function) + else { + continue + } + + let enumType = enumDict[enumTypeAndCase.enumType]! + let newArg = specializePayloadTupleBBArgInPullback( + arg: arg, + enumTypeAndCase: (enumType: enumType, caseIdx: enumTypeAndCase.caseIdx), + context: context) + print("\(newArg)") + bb.eraseArgument(at: newArg.index, context) + } + print("") +} diff --git a/SwiftCompilerSources/Sources/Optimizer/Utilities/FunctionTest.swift b/SwiftCompilerSources/Sources/Optimizer/Utilities/FunctionTest.swift index 701fe2819e86d..c5ce5f6135352 100644 --- a/SwiftCompilerSources/Sources/Optimizer/Utilities/FunctionTest.swift +++ b/SwiftCompilerSources/Sources/Optimizer/Utilities/FunctionTest.swift @@ -43,6 +43,7 @@ public func registerOptimizerTests() { registerFunctionTests( addressOwnershipLiveRangeTest, argumentConventionsTest, + getPullbackClosureInfoMultiBBTest, interiorLivenessTest, lifetimeDependenceRootTest, lifetimeDependenceScopeTest, @@ -51,6 +52,9 @@ public func registerOptimizerTests() { localVariableReachableUsesTest, localVariableReachingAssignmentsTest, rangeOverlapsPathTest, + specializeBranchTracingEnums, + specializeBTEArgInVjpBB, + specializePayloadArgInPullbackBB, variableIntroducerTest ) diff --git a/include/swift/SIL/SILBridging.h b/include/swift/SIL/SILBridging.h index e5e940e121264..10c97880d5a3e 100644 --- a/include/swift/SIL/SILBridging.h +++ b/include/swift/SIL/SILBridging.h @@ -579,6 +579,7 @@ struct BridgedFunction { bool isConvertPointerToPointerArgument() const; bool isAddressor() const; bool isAutodiffVJP() const; + bool isAutodiffSubsetParametersThunk() const; SwiftInt specializationLevel() const; SWIFT_IMPORT_UNSAFE BridgedSubstitutionMap getMethodSubstitutions(BridgedSubstitutionMap contextSubs, BridgedCanType selfType) const; diff --git a/lib/SILOptimizer/Utils/OptimizerBridging.cpp b/lib/SILOptimizer/Utils/OptimizerBridging.cpp index 1b64bd17e644f..78ad919c8bf98 100644 --- a/lib/SILOptimizer/Utils/OptimizerBridging.cpp +++ b/lib/SILOptimizer/Utils/OptimizerBridging.cpp @@ -13,6 +13,7 @@ #include "swift/SILOptimizer/OptimizerBridging.h" #include "../../IRGen/IRGenModule.h" #include "swift/AST/SemanticAttrs.h" +#include "swift/Demangling/ManglingMacros.h" #include "swift/SIL/DynamicCasts.h" #include "swift/SIL/OSSACompleteLifetime.h" #include "swift/SIL/SILCloner.h" @@ -523,6 +524,73 @@ bool BridgedFunction::isAutodiffVJP() const { getFunction(), swift::AutoDiffFunctionComponent::VJP); } +bool BridgedFunction::isAutodiffSubsetParametersThunk() const { + Demangle::Context Ctx; + if (auto *root = Ctx.demangleSymbolAsNode(getFunction()->getName())) { + // root node has Global kind, the AutoDiffSubsetParametersThunk node (if + // present) is direct child of root. + return root->findByKind(Demangle::Node::Kind::AutoDiffSubsetParametersThunk, + /*maxDepth=*/1) != nullptr; + } + return false; +} + +// See also ASTMangler::mangleAutoDiffGeneratedDeclaration. +bool BridgedType::isAutodiffBranchTracingEnumInVJP(BridgedFunction vjp) const { + assert(vjp.isAutodiffVJP()); + EnumDecl *ed = unbridged().getEnumOrBoundGenericEnum(); + if (ed == nullptr) + return false; + + llvm::StringRef edName = ed->getNameStr(); + if (!edName.starts_with("_AD__")) + return false; + if (!llvm::StringRef(edName.data() + 5, edName.size() - 5) + .starts_with(MANGLING_PREFIX_STR)) + return false; + + // At this point, we know that the type is indeed a branch tracing enum. + // Now we need to ensure that it is the enum related to the given VJP. + + std::size_t idx = edName.rfind("__Pred__"); + assert(idx != std::string::npos); + + // Before "__Pred__", we have "_bbX", where X is a number. + // The loop calculates the start position of X. + for (; idx != 0 && std::isdigit(edName[idx - 1]); --idx) + ; + + assert(std::isdigit(edName[idx])); + assert(!std::isdigit(edName[idx - 1])); + + // The branch tracing enum decl name has the following components: + // 1) "_AD__"; + // 2) MANGLING_PREFIX; + // 3) original function name; + // 4) "_bb"; + // 5) X at position idx (see above); + // 6) the rest of the enum decl name. + // Thus, "_AD__", MANGLING_PREFIX and "_bb" must have total length less than + // idx. + std::size_t manglingPrefixSize = std::strlen(MANGLING_PREFIX_STR); + assert(idx > 5 + manglingPrefixSize + 3); + assert(std::string_view(edName.data() + idx - 3, 3) == "_bb"); + assert(std::string_view(edName.data(), 5 + manglingPrefixSize) == "_AD__$s"); + + llvm::StringRef enumOrigFuncName = + std::string_view(edName.data() + 5 + manglingPrefixSize, + idx - (5 + manglingPrefixSize + 3)); + + Demangle::Context Ctx; + if (auto *root = Ctx.demangleSymbolAsNode(vjp.getFunction()->getName())) + if (auto *node = + root->findByKind(Demangle::Node::Kind::Function, /*maxDepth=*/3)) + if (mangleNode(node).result() == enumOrigFuncName) + return true; + + return false; +} + SwiftInt BridgedFunction::specializationLevel() const { return swift::getSpecializationLevel(getFunction()); } diff --git a/test/AutoDiff/SILOptimizer/closure_specialization/multi_bb_bte.sil b/test/AutoDiff/SILOptimizer/closure_specialization/multi_bb_bte.sil index 7d705d7eaf4b8..e04caf3329369 100644 --- a/test/AutoDiff/SILOptimizer/closure_specialization/multi_bb_bte.sil +++ b/test/AutoDiff/SILOptimizer/closure_specialization/multi_bb_bte.sil @@ -1,5 +1,6 @@ /// Multi basic block VJP, pullback accepting branch tracing enum argument. +// RUN: %target-sil-opt -sil-print-types -test-runner %s -o /dev/null 2>&1 | %FileCheck %s --check-prefixes=TRUNNER // RUN: %target-sil-opt -sil-print-types -autodiff-closure-specialization -sil-combine %s -o - | %FileCheck %s --check-prefixes=COMBINE,CHECK // REQUIRES: swift_in_compiler @@ -44,6 +45,35 @@ bb0(%0 : $Float, %1 : $_AD__$s4test5mul42yS2fSgF_bb2__Pred__src_0_wrt_0, %2 : @o // reverse-mode derivative of mul42(_:) sil hidden [ossa] @$s4test5mul42yS2fSgFTJrSpSr : $@convention(thin) (Optional) -> (Float, @owned @callee_guaranteed (Float) -> Optional.TangentVector) { bb0(%0 : $Optional): + specify_test "autodiff_closure_specialize_get_pullback_closure_info_multi_bb" + // TRUNNER-LABEL: Run getPullbackClosureInfoMultiBB for VJP $s4test5mul42yS2fSgFTJrSpSr: pullbackClosureInfo = ( + // TRUNNER-NEXT: pullbackFn = $s4test5mul42yS2fSgFTJpSpSr + // TRUNNER-NEXT: closureInfosMultiBB = [ + // TRUNNER-NEXT: ]{{$}} + // TRUNNER-NEXT: ){{$}} + // TRUNNER-EMPTY: + + //=========== Test specialized branch tracing enums ===========// + specify_test "autodiff_specialize_branch_tracing_enums" + // TRUNNER-LABEL: Specialized branch tracing enum dict for VJP $s4test5mul42yS2fSgFTJrSpSr contains 1 elements: + // TRUNNER-NEXT: non-specialized BTE 0: enum _AD__$s4test5mul42yS2fSgF_bb2__Pred__src_0_wrt_0 { + // TRUNNER-NEXT: case bb0(()) + // TRUNNER-NEXT: } + // TRUNNER-NEXT: specialized BTE 0: enum _AD__$s4test5mul42yS2fSgF_bb2__Pred__src_0_wrt_0_spec { + // TRUNNER-NEXT: case bb0(()) + // TRUNNER-NEXT: } + // TRUNNER-EMPTY: + + //=========== Test specialized branch tracing enum args in VJP ===========// + specify_test "autodiff_specialize_bte_arg_in_vjp_bb" + // TRUNNER-LABEL: Specialized BTE arguments of basic blocks in VJP $s4test5mul42yS2fSgFTJrSpSr: + // TRUNNER-EMPTY: + + //=========== Test specialized branch tracing enum payload args in pullback ===========// + specify_test "autodiff_specialize_payload_arg_in_pb_bb" + // TRUNNER-LABEL: Specialized BTE payload arguments of basic blocks in pullback $s4test5mul42yS2fSgFTJpSpSr: + // TRUNNER-EMPTY: + // CHECK: sil private [signature_optimized_thunk] [heuristic_always_inline] [ossa] @$s4test5mul42yS2fSgFTJpSpSr073$sSf16_DifferentiationE12_vjpMultiply3lhs3rhsSf5value_Sf_SftSfc8pullbacktj1_k5FZSf_K6SfcfU_S2fTf1nnc_n : $@convention(thin) (Float, _AD__$s4test5mul42yS2fSgF_bb2__Pred__src_0_wrt_0, Float, Float) -> Optional.TangentVector { // CHECK: bb0(%0 : $Float, %1 : $_AD__$s4test5mul42yS2fSgF_bb2__Pred__src_0_wrt_0, %2 : $Float, %3 : $Float): // CHECK: %[[#A6:]] = function_ref @$sS3fIegydd_TJSpSSUpSrUSUP : $@convention(thin) (Float, @guaranteed @callee_guaranteed (Float) -> (Float, Float)) -> Float @@ -271,6 +301,114 @@ bb6(%122 : $Builtin.FPIEEE32): // reverse-mode derivative of Class.method() sil hidden [ossa] @$s4test5ClassV6methodSfyFTJrSpSr : $@convention(method) (Class) -> (Float, @owned @callee_guaranteed (Float) -> Class.TangentVector) { bb0(%0 : $Class): + specify_test "autodiff_closure_specialize_get_pullback_closure_info_multi_bb" + // TRUNNER-LABEL: Run getPullbackClosureInfoMultiBB for VJP $s4test5ClassV6methodSfyFTJrSpSr: pullbackClosureInfo = ( + // TRUNNER-NEXT: pullbackFn = $s4test5ClassV6methodSfyFTJpSpSr + // TRUNNER-NEXT: closureInfosMultiBB = [ + // TRUNNER-NEXT: ClosureInfoMultiBB( + // TRUNNER-NEXT: closure: %[[#]] = partial_apply [callee_guaranteed] %[[#]](%[[#]], %[[#]]) : $@convention(thin) (Float, Float, Float) -> (Float, Float) + // TRUNNER-NEXT: capturedArgs: [ + // TRUNNER-NEXT: %[[#]] = struct_extract %0 : $Class, #Class.stored + // TRUNNER-NEXT: %[[#]] = struct $Float (%[[#]] : $Builtin.FPIEEE32) + // TRUNNER-NEXT: ]{{$}} + // TRUNNER-NEXT: subsetThunk: %[[#]] = partial_apply [callee_guaranteed] %[[#]](%[[#]]) : $@convention(thin) (Float, @guaranteed @callee_guaranteed (Float) -> (Float, Float)) -> Float + // TRUNNER-NEXT: payloadTuple: %[[#]] = tuple (%[[#]] : $@callee_guaranteed (Float) -> Float, %[[#]] : $@callee_guaranteed (Class.TangentVector) -> (Float, Optional.TangentVector)) + // TRUNNER-NEXT: idxInPayload: 0{{$}} + // TRUNNER-NEXT: enumTypeAndCase: (enumType: $_AD__$s4test5ClassV6methodSfyF_bb2__Pred__src_0_wrt_0, caseIdx: 0) + // TRUNNER-NEXT: ){{$}} + // TRUNNER-NEXT: ClosureInfoMultiBB( + // TRUNNER-NEXT: closure: %[[#]] = partial_apply [callee_guaranteed] %[[#]](%[[#]], %[[#]]) : $@convention(thin) (Float, Float, Float) -> (Float, Float) + // TRUNNER-NEXT: capturedArgs: [ + // TRUNNER-NEXT: %[[#]] = struct_extract %0 : $Class, #Class.stored + // TRUNNER-NEXT: %[[#]] = struct $Float (%[[#]] : $Builtin.FPIEEE32) + // TRUNNER-NEXT: ]{{$}} + // TRUNNER-NEXT: subsetThunk: %[[#]] = partial_apply [callee_guaranteed] %[[#]](%[[#]]) : $@convention(thin) (Float, @guaranteed @callee_guaranteed (Float) -> (Float, Float)) -> Float + // TRUNNER-NEXT: payloadTuple: %[[#]] = tuple (%[[#]] : $@callee_guaranteed (Float) -> Float, %[[#]] : $@callee_guaranteed (Class.TangentVector) -> (Float, Optional.TangentVector)) + // TRUNNER-NEXT: idxInPayload: 0{{$}} + // TRUNNER-NEXT: enumTypeAndCase: (enumType: $_AD__$s4test5ClassV6methodSfyF_bb1__Pred__src_0_wrt_0, caseIdx: 0) + // TRUNNER-NEXT: ){{$}} + // TRUNNER-NEXT: ClosureInfoMultiBB( + // TRUNNER-NEXT: closure: %[[#]] = thin_to_thick_function %[[#]] : $@convention(thin) (Class.TangentVector) -> (Float, Optional.TangentVector) to $@callee_guaranteed (Class.TangentVector) -> (Float, Optional.TangentVector) + // TRUNNER-NEXT: capturedArgs: [ + // TRUNNER-NEXT: ]{{$}} + // TRUNNER-NEXT: subsetThunk: nil + // TRUNNER-NEXT: payloadTuple: %[[#]] = tuple (%[[#]] : $@callee_guaranteed (Float) -> Float, %[[#]] : $@callee_guaranteed (Class.TangentVector) -> (Float, Optional.TangentVector)) + // TRUNNER-NEXT: idxInPayload: 1{{$}} + // TRUNNER-NEXT: enumTypeAndCase: (enumType: $_AD__$s4test5ClassV6methodSfyF_bb2__Pred__src_0_wrt_0, caseIdx: 0) + // TRUNNER-NEXT: ){{$}} + // TRUNNER-NEXT: ClosureInfoMultiBB( + // TRUNNER-NEXT: closure: %[[#]] = thin_to_thick_function %[[#]] : $@convention(thin) (Class.TangentVector) -> (Float, Optional.TangentVector) to $@callee_guaranteed (Class.TangentVector) -> (Float, Optional.TangentVector) + // TRUNNER-NEXT: capturedArgs: [ + // TRUNNER-NEXT: ]{{$}} + // TRUNNER-NEXT: subsetThunk: nil + // TRUNNER-NEXT: payloadTuple: %[[#]] = tuple (%[[#]] : $@callee_guaranteed (Float) -> Float, %[[#]] : $@callee_guaranteed (Class.TangentVector) -> (Float, Optional.TangentVector)) + // TRUNNER-NEXT: idxInPayload: 1{{$}} + // TRUNNER-NEXT: enumTypeAndCase: (enumType: $_AD__$s4test5ClassV6methodSfyF_bb1__Pred__src_0_wrt_0, caseIdx: 0) + // TRUNNER-NEXT: ){{$}} + // TRUNNER-NEXT: ClosureInfoMultiBB( + // TRUNNER-NEXT: closure: %[[#]] = partial_apply [callee_guaranteed] %[[#]](%[[#]], %[[#]]) : $@convention(thin) (Float, Float, Float) -> (Float, Float) + // TRUNNER-NEXT: capturedArgs: [ + // TRUNNER-NEXT: %[[#]] = struct $Float (%[[#]] : $Builtin.FPIEEE32) + // TRUNNER-NEXT: %[[#]] = argument of bb1 : $Float + // TRUNNER-NEXT: ]{{$}} + // TRUNNER-NEXT: subsetThunk: nil + // TRUNNER-NEXT: payloadTuple: %[[#]] = tuple $(predecessor: _AD__$s4test5ClassV6methodSfyF_bb1__Pred__src_0_wrt_0, @callee_guaranteed (Float) -> (Float, Float)) (%[[#]], %[[#]]) + // TRUNNER-NEXT: idxInPayload: 1{{$}} + // TRUNNER-NEXT: enumTypeAndCase: (enumType: $_AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0, caseIdx: 1) + // TRUNNER-NEXT: ){{$}} + // TRUNNER-NEXT: ClosureInfoMultiBB( + // TRUNNER-NEXT: closure: %[[#]] = partial_apply [callee_guaranteed] %[[#]](%[[#]], %[[#]]) : $@convention(thin) (Float, Float, Float) -> (Float, Float) + // TRUNNER-NEXT: capturedArgs: [ + // TRUNNER-NEXT: %[[#]] = struct $Float (%[[#]] : $Builtin.FPIEEE32) + // TRUNNER-NEXT: %[[#]] = struct $Float (%[[#]] : $Builtin.FPIEEE32) + // TRUNNER-NEXT: ]{{$}} + // TRUNNER-NEXT: subsetThunk: %[[#]] = partial_apply [callee_guaranteed] %[[#]](%[[#]]) : $@convention(thin) (Float, @guaranteed @callee_guaranteed (Float) -> (Float, Float)) -> Float + // TRUNNER-NEXT: payloadTuple: %[[#]] = tuple $(predecessor: _AD__$s4test5ClassV6methodSfyF_bb2__Pred__src_0_wrt_0, @callee_guaranteed (Float) -> Float) (%[[#]], %[[#]]) + // TRUNNER-NEXT: idxInPayload: 1{{$}} + // TRUNNER-NEXT: enumTypeAndCase: (enumType: $_AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0, caseIdx: 0) + // TRUNNER-NEXT: ){{$}} + // TRUNNER-NEXT: ]{{$}} + // TRUNNER-NEXT: ){{$}} + // TRUNNER-EMPTY: + + //=========== Test specialized branch tracing enums ===========// + specify_test "autodiff_specialize_branch_tracing_enums" + // TRUNNER-LABEL: Specialized branch tracing enum dict for VJP $s4test5ClassV6methodSfyFTJrSpSr contains 3 elements: + // TRUNNER-NEXT: non-specialized BTE 0: enum _AD__$s4test5ClassV6methodSfyF_bb1__Pred__src_0_wrt_0 { + // TRUNNER-NEXT: case bb0(((Float) -> Float, (Class.TangentVector) -> (Float, Optional.TangentVector))) + // TRUNNER-NEXT: } + // TRUNNER-NEXT: specialized BTE 0: enum _AD__$s4test5ClassV6methodSfyF_bb1__Pred__src_0_wrt_0_spec_bb0_0_1 { + // TRUNNER-NEXT: case bb0(((Float, Float), ())) + // TRUNNER-NEXT: } + // TRUNNER-NEXT: non-specialized BTE 1: enum _AD__$s4test5ClassV6methodSfyF_bb2__Pred__src_0_wrt_0 { + // TRUNNER-NEXT: case bb0(((Float) -> Float, (Class.TangentVector) -> (Float, Optional.TangentVector))) + // TRUNNER-NEXT: } + // TRUNNER-NEXT: specialized BTE 1: enum _AD__$s4test5ClassV6methodSfyF_bb2__Pred__src_0_wrt_0_spec_bb0_0_1 { + // TRUNNER-NEXT: case bb0(((Float, Float), ())) + // TRUNNER-NEXT: } + // TRUNNER-NEXT: non-specialized BTE 2: enum _AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0 { + // TRUNNER-NEXT: case bb2((predecessor: _AD__$s4test5ClassV6methodSfyF_bb2__Pred__src_0_wrt_0, (Float) -> Float)) + // TRUNNER-NEXT: case bb1((predecessor: _AD__$s4test5ClassV6methodSfyF_bb1__Pred__src_0_wrt_0, (Float) -> (Float, Float))) + // TRUNNER-NEXT: } + // TRUNNER-NEXT: specialized BTE 2: enum _AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0_spec_bb2_1_bb1_1 { + // TRUNNER-NEXT: case bb2((predecessor: _AD__$s4test5ClassV6methodSfyF_bb2__Pred__src_0_wrt_0_spec_bb0_0_1, (Float, Float))) + // TRUNNER-NEXT: case bb1((predecessor: _AD__$s4test5ClassV6methodSfyF_bb1__Pred__src_0_wrt_0_spec_bb0_0_1, (Float, Float))) + // TRUNNER-NEXT: } + // TRUNNER-EMPTY: + + //=========== Test specialized branch tracing enum args in VJP ===========// + specify_test "autodiff_specialize_bte_arg_in_vjp_bb" + // TRUNNER-LABEL: Specialized BTE arguments of basic blocks in VJP $s4test5ClassV6methodSfyFTJrSpSr: + // TRUNNER-NEXT: %[[#]] = argument of bb3 : $_AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0_spec_bb2_1_bb1_1 + // TRUNNER-EMPTY: + + //=========== Test specialized branch tracing enum payload args in pullback ===========// + specify_test "autodiff_specialize_payload_arg_in_pb_bb" + // TRUNNER-LABEL: Specialized BTE payload arguments of basic blocks in pullback $s4test5ClassV6methodSfyFTJpSpSr: + // TRUNNER-NEXT: %[[#]] = argument of bb1 : $(predecessor: _AD__$s4test5ClassV6methodSfyF_bb2__Pred__src_0_wrt_0_spec_bb0_0_1, (Float, Float)) + // TRUNNER-NEXT: %[[#]] = argument of bb2 : $(predecessor: _AD__$s4test5ClassV6methodSfyF_bb1__Pred__src_0_wrt_0_spec_bb0_0_1, (Float, Float)) + // TRUNNER-EMPTY: + // CHECK: sil private [ossa] @$s4test5ClassV6methodSfyFTJpSpSr073$sSf16_DifferentiationE12_vjpMultiply3lhs3rhsSf5value_Sf_SftSfc8pullbacktk1_l5FZSf_L6SfcfU_S2fAES2fTf1nncc_n : $@convention(thin) (Float, @owned _AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0, Float, Float, Float, Float) -> Class.TangentVector { // CHECK: bb0(%0 : $Float, %1 : @owned $_AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0, %2 : $Float, %3 : $Float, %4 : $Float, %5 : $Float): // CHECK: %[[#D8:]] = function_ref @$sS3fIegydd_TJSpSSUpSrUSUP : $@convention(thin) (Float, @guaranteed @callee_guaranteed (Float) -> (Float, Float)) -> Float @@ -506,6 +644,117 @@ bb3(%118 : $Builtin.FPIEEE32, %119 : $Builtin.FPIEEE32, %120 : $Builtin.FPIEEE32 sil hidden [ossa] @$s4test14cond_tuple_varyS2fFTJrSpSr : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) { [global: ] bb0(%0 : $Float): + specify_test "autodiff_closure_specialize_get_pullback_closure_info_multi_bb" + // TRUNNER-LABEL: Run getPullbackClosureInfoMultiBB for VJP $s4test14cond_tuple_varyS2fFTJrSpSr: pullbackClosureInfo = ( + // TRUNNER-NEXT: pullbackFn = $s4test14cond_tuple_varyS2fFTJpSpSr + // TRUNNER-NEXT: closureInfosMultiBB = [ + // TRUNNER-NEXT: ClosureInfoMultiBB( + // TRUNNER-NEXT: closure: %[[#]] = thin_to_thick_function %[[#]] : $@convention(thin) (Float) -> (Float, Float) to $@callee_guaranteed (Float) -> (Float, Float) + // TRUNNER-NEXT: capturedArgs: [ + // TRUNNER-NEXT: ] + // TRUNNER-NEXT: subsetThunk: nil + // TRUNNER-NEXT: payloadTuple: %[[#]] = tuple (%[[#]] : $@callee_guaranteed (Float) -> (Float, Float), %[[#]] : $@callee_guaranteed (Float) -> (Float, Float)) + // TRUNNER-NEXT: idxInPayload: 0{{$}} + // TRUNNER-NEXT: enumTypeAndCase: (enumType: $_AD__$s4test14cond_tuple_varyS2fF_bb2__Pred__src_0_wrt_0, caseIdx: 0) + // TRUNNER-NEXT: ){{$}} + // TRUNNER-NEXT: ClosureInfoMultiBB( + // TRUNNER-NEXT: closure: %[[#]] = thin_to_thick_function %[[#]] : $@convention(thin) (Float) -> (Float, Float) to $@callee_guaranteed (Float) -> (Float, Float) + // TRUNNER-NEXT: capturedArgs: [ + // TRUNNER-NEXT: ] + // TRUNNER-NEXT: subsetThunk: nil + // TRUNNER-NEXT: payloadTuple: %[[#]] = tuple (%[[#]] : $@callee_guaranteed (Float) -> (Float, Float), %[[#]] : $@callee_guaranteed (Float) -> (Float, Float)) + // TRUNNER-NEXT: idxInPayload: 0{{$}} + // TRUNNER-NEXT: enumTypeAndCase: (enumType: $_AD__$s4test14cond_tuple_varyS2fF_bb1__Pred__src_0_wrt_0, caseIdx: 0) + // TRUNNER-NEXT: ){{$}} + // TRUNNER-NEXT: ClosureInfoMultiBB( + // TRUNNER-NEXT: closure: %[[#]] = thin_to_thick_function %[[#]] : $@convention(thin) (Float) -> (Float, Float) to $@callee_guaranteed (Float) -> (Float, Float) + // TRUNNER-NEXT: capturedArgs: [ + // TRUNNER-NEXT: ] + // TRUNNER-NEXT: subsetThunk: nil + // TRUNNER-NEXT: payloadTuple: %[[#]] = tuple (%[[#]] : $@callee_guaranteed (Float) -> (Float, Float), %[[#]] : $@callee_guaranteed (Float) -> (Float, Float)) + // TRUNNER-NEXT: idxInPayload: 1{{$}} + // TRUNNER-NEXT: enumTypeAndCase: (enumType: $_AD__$s4test14cond_tuple_varyS2fF_bb2__Pred__src_0_wrt_0, caseIdx: 0) + // TRUNNER-NEXT: ){{$}} + // TRUNNER-NEXT: ClosureInfoMultiBB( + // TRUNNER-NEXT: closure: %[[#]] = thin_to_thick_function %[[#]] : $@convention(thin) (Float) -> (Float, Float) to $@callee_guaranteed (Float) -> (Float, Float) + // TRUNNER-NEXT: capturedArgs: [ + // TRUNNER-NEXT: ] + // TRUNNER-NEXT: subsetThunk: nil + // TRUNNER-NEXT: payloadTuple: %[[#]] = tuple (%[[#]] : $@callee_guaranteed (Float) -> (Float, Float), %[[#]] : $@callee_guaranteed (Float) -> (Float, Float)) + // TRUNNER-NEXT: idxInPayload: 1{{$}} + // TRUNNER-NEXT: enumTypeAndCase: (enumType: $_AD__$s4test14cond_tuple_varyS2fF_bb1__Pred__src_0_wrt_0, caseIdx: 0) + // TRUNNER-NEXT: ){{$}} + // TRUNNER-NEXT: ClosureInfoMultiBB( + // TRUNNER-NEXT: closure: %[[#]] = thin_to_thick_function %[[#]] : $@convention(thin) (Float) -> (Float, Float) to $@callee_guaranteed (Float) -> (Float, Float) + // TRUNNER-NEXT: capturedArgs: [ + // TRUNNER-NEXT: ] + // TRUNNER-NEXT: subsetThunk: nil + // TRUNNER-NEXT: payloadTuple: %[[#]] = tuple $(predecessor: _AD__$s4test14cond_tuple_varyS2fF_bb1__Pred__src_0_wrt_0, @callee_guaranteed (Float) -> (Float, Float), @callee_guaranteed (Float) -> (Float, Float)) (%[[#]], %[[#]], %[[#]]) + // TRUNNER-NEXT: idxInPayload: 1{{$}} + // TRUNNER-NEXT: enumTypeAndCase: (enumType: $_AD__$s4test14cond_tuple_varyS2fF_bb3__Pred__src_0_wrt_0, caseIdx: 1) + // TRUNNER-NEXT: ){{$}} + // TRUNNER-NEXT: ClosureInfoMultiBB( + // TRUNNER-NEXT: closure: %[[#]] = thin_to_thick_function %[[#]] : $@convention(thin) (Float) -> (Float, Float) to $@callee_guaranteed (Float) -> (Float, Float) + // TRUNNER-NEXT: capturedArgs: [ + // TRUNNER-NEXT: ] + // TRUNNER-NEXT: subsetThunk: nil + // TRUNNER-NEXT: payloadTuple: %[[#]] = tuple $(predecessor: _AD__$s4test14cond_tuple_varyS2fF_bb1__Pred__src_0_wrt_0, @callee_guaranteed (Float) -> (Float, Float), @callee_guaranteed (Float) -> (Float, Float)) (%[[#]], %[[#]], %[[#]]) + // TRUNNER-NEXT: idxInPayload: 2{{$}} + // TRUNNER-NEXT: enumTypeAndCase: (enumType: $_AD__$s4test14cond_tuple_varyS2fF_bb3__Pred__src_0_wrt_0, caseIdx: 1) + // TRUNNER-NEXT: ){{$}} + // TRUNNER-NEXT: ClosureInfoMultiBB( + // TRUNNER-NEXT: closure: %[[#]] = partial_apply [callee_guaranteed] %[[#]](%0, %[[#]]) : $@convention(thin) (Float, Float, Float) -> (Float, Float) + // TRUNNER-NEXT: capturedArgs: [ + // TRUNNER-NEXT: %0 = argument of bb0 : $Float + // TRUNNER-NEXT: %[[#]] = struct $Float (%[[#]] : $Builtin.FPIEEE32) + // TRUNNER-NEXT: ] + // TRUNNER-NEXT: subsetThunk: %[[#]] = partial_apply [callee_guaranteed] %[[#]](%[[#]]) : $@convention(thin) (Float, @guaranteed @callee_guaranteed (Float) -> (Float, Float)) -> Float + // TRUNNER-NEXT: payloadTuple: %[[#]] = tuple $(predecessor: _AD__$s4test14cond_tuple_varyS2fF_bb2__Pred__src_0_wrt_0, @callee_guaranteed (Float) -> Float) (%[[#]], %[[#]]) + // TRUNNER-NEXT: idxInPayload: 1{{$}} + // TRUNNER-NEXT: enumTypeAndCase: (enumType: $_AD__$s4test14cond_tuple_varyS2fF_bb3__Pred__src_0_wrt_0, caseIdx: 0) + // TRUNNER-NEXT: ){{$}} + // TRUNNER-NEXT: ]{{$}} + // TRUNNER-NEXT: ){{$}} + // TRUNNER-EMPTY: + + //=========== Test specialized branch tracing enums ===========// + specify_test "autodiff_specialize_branch_tracing_enums" + // TRUNNER-LABEL: Specialized branch tracing enum dict for VJP $s4test14cond_tuple_varyS2fFTJrSpSr contains 3 elements: + // TRUNNER-NEXT: non-specialized BTE 0: enum _AD__$s4test14cond_tuple_varyS2fF_bb1__Pred__src_0_wrt_0 { + // TRUNNER-NEXT: case bb0(((Float) -> (Float, Float), (Float) -> (Float, Float))) + // TRUNNER-NEXT: } + // TRUNNER-NEXT: specialized BTE 0: enum _AD__$s4test14cond_tuple_varyS2fF_bb1__Pred__src_0_wrt_0_spec_bb0_0_1 { + // TRUNNER-NEXT: case bb0(((), ())) + // TRUNNER-NEXT: } + // TRUNNER-NEXT: non-specialized BTE 1: enum _AD__$s4test14cond_tuple_varyS2fF_bb2__Pred__src_0_wrt_0 { + // TRUNNER-NEXT: case bb0(((Float) -> (Float, Float), (Float) -> (Float, Float))) + // TRUNNER-NEXT: } + // TRUNNER-NEXT: specialized BTE 1: enum _AD__$s4test14cond_tuple_varyS2fF_bb2__Pred__src_0_wrt_0_spec_bb0_0_1 { + // TRUNNER-NEXT: case bb0(((), ())) + // TRUNNER-NEXT: } + // TRUNNER-NEXT: non-specialized BTE 2: enum _AD__$s4test14cond_tuple_varyS2fF_bb3__Pred__src_0_wrt_0 { + // TRUNNER-NEXT: case bb2((predecessor: _AD__$s4test14cond_tuple_varyS2fF_bb2__Pred__src_0_wrt_0, (Float) -> Float)) + // TRUNNER-NEXT: case bb1((predecessor: _AD__$s4test14cond_tuple_varyS2fF_bb1__Pred__src_0_wrt_0, (Float) -> (Float, Float), (Float) -> (Float, Float))) + // TRUNNER-NEXT: } + // TRUNNER-NEXT: specialized BTE 2: enum _AD__$s4test14cond_tuple_varyS2fF_bb3__Pred__src_0_wrt_0_spec_bb2_1_bb1_1_2 { + // TRUNNER-NEXT: case bb2((predecessor: _AD__$s4test14cond_tuple_varyS2fF_bb2__Pred__src_0_wrt_0_spec_bb0_0_1, (Float, Float))) + // TRUNNER-NEXT: case bb1((predecessor: _AD__$s4test14cond_tuple_varyS2fF_bb1__Pred__src_0_wrt_0_spec_bb0_0_1, (), ())) + // TRUNNER-NEXT: } + // TRUNNER-EMPTY: + + //=========== Test specialized branch tracing enum args in VJP ===========// + specify_test "autodiff_specialize_bte_arg_in_vjp_bb" + // TRUNNER-LABEL: Specialized BTE arguments of basic blocks in VJP $s4test14cond_tuple_varyS2fFTJrSpSr: + // TRUNNER-NEXT: %[[#]] = argument of bb3 : $_AD__$s4test14cond_tuple_varyS2fF_bb3__Pred__src_0_wrt_0_spec_bb2_1_bb1_1_2 + // TRUNNER-EMPTY: + + //=========== Test specialized branch tracing enum payload args in pullback ===========// + specify_test "autodiff_specialize_payload_arg_in_pb_bb" + // TRUNNER-LABEL: Specialized BTE payload arguments of basic blocks in pullback $s4test14cond_tuple_varyS2fFTJpSpSr: + // TRUNNER-NEXT: %[[#]] = argument of bb1 : $(predecessor: _AD__$s4test14cond_tuple_varyS2fF_bb2__Pred__src_0_wrt_0_spec_bb0_0_1, (Float, Float)) + // TRUNNER-NEXT: %[[#]] = argument of bb2 : $(predecessor: _AD__$s4test14cond_tuple_varyS2fF_bb1__Pred__src_0_wrt_0_spec_bb0_0_1, (), ()) + // TRUNNER-EMPTY: + // CHECK: sil private [ossa] @$s4test14cond_tuple_varyS2fFTJpSpSr067$sSf16_DifferentiationE7_vjpAdd3lhs3rhsSf5value_Sf_SftSfc8pullbacktl1_m5FZSf_M6SfcfU_0ef1_g4E12_i16Subtract3lhs3rhsk1_l1_mnl1_mo1_mP2U_ACTf1nnccc_n : $@convention(thin) (Float, @owned _AD__$s4test14cond_tuple_varyS2fF_bb3__Pred__src_0_wrt_0) -> Float { // CHECK: bb0(%0 : $Float, %1 : @owned $_AD__$s4test14cond_tuple_varyS2fF_bb3__Pred__src_0_wrt_0): // CHECK: %[[#F2:]] = function_ref @$sSf16_DifferentiationE7_vjpAdd3lhs3rhsSf5value_Sf_SftSfc8pullbacktSf_SftFZSf_SftSfcfU_ : $@convention(thin) (Float) -> (Float, Float)