diff --git a/SwiftCompilerSources/Sources/Optimizer/FunctionPasses/ClosureSpecialization.swift b/SwiftCompilerSources/Sources/Optimizer/FunctionPasses/ClosureSpecialization.swift index be811c64ebf6d..c1aee0a0168ce 100644 --- a/SwiftCompilerSources/Sources/Optimizer/FunctionPasses/ClosureSpecialization.swift +++ b/SwiftCompilerSources/Sources/Optimizer/FunctionPasses/ClosureSpecialization.swift @@ -107,7 +107,7 @@ private let verbose = false private func log(prefix: Bool = true, _ message: @autoclosure () -> String) { if verbose { - debugLog(prefix: prefix, message()) + debugLog(prefix: prefix, "[ADCS] " + message()) } } @@ -120,6 +120,352 @@ let generalClosureSpecialization = FunctionPass( print("NOT IMPLEMENTED") } +extension Collection { + func getExactlyOneOrNil() -> Element? { + assert(self.count <= 1) + return self.first + } +} + +extension Type: Hashable { + func isBranchTracingEnumIn(vjp: Function) -> Bool { + return self.bridged.isAutodiffBranchTracingEnumInVJP(vjp.bridged) + } + + public func hash(into hasher: inout Hasher) { + hasher.combine(bridged.opaqueValue) + } +} + +extension BasicBlock { + fileprivate func getBranchTracingEnumArg(vjp: Function) -> Argument? { + return self.arguments.filter { $0.type.isBranchTracingEnumIn(vjp: vjp) }.getExactlyOneOrNil() + } +} + +extension Decl { + public var asDecl: BridgedDecl { BridgedDecl(raw: bridged.obj) } +} + +extension BridgedDecl { + public var asDeclObj: BridgedDeclObj { + BridgedDeclObj(SwiftObject(raw.bindMemory(to: BridgedSwiftObject.self, capacity: 1))) + } + public var decl: Decl { asDeclObj.decl } +} + +typealias ClosureInfoMultiBB = ( + closure: SingleValueInstruction, + capturedArgs: [Value], + subsetThunk: PartialApplyInst?, + payloadTuple: TupleInst, + idxInPayload: Int, + enumTypeAndCase: EnumTypeAndCase +) + +typealias EnumTypeAndCase = ( + enumType: Type, + caseIdx: Int +) + +typealias BTEToPredsDict = [SIL.`Type`: [SIL.`Type`]] + +typealias BTECaseToClosureListDict = [Int: [ClosureInfoMultiBB]] + +typealias SpecBTEDict = [SIL.`Type`: SIL.`Type`] + +func getCapturedArgTypesTupleForClosure( + closure: SingleValueInstruction, context: FunctionPassContext +) -> AST.`Type` { + var capturedArgTypes = [AST.`Type`]() + if let pai = closure as? PartialApplyInst { + for arg in pai.arguments { + capturedArgTypes.append(arg.type.rawType) + } + } else if closure as? ThinToThickFunctionInst != nil { + // Nothing captured + } else { + assert(false) + } + return context.getTupleType(elements: capturedArgTypes) +} + +func getBranchTracingEnumPreds(bteType: SIL.`Type`, vjp: Function) -> [SIL.`Type`] { + var btePreds = [SIL.`Type`]() + guard let enumCases = bteType.getEnumCases(in: vjp) else { + return btePreds + } + for enumCase in enumCases { + let payloadType: SIL.`Type` = enumCase.payload! + if payloadType.tupleElements.count == 0 { + continue + } + let firstTupleElementType: SIL.`Type` = payloadType.tupleElements[0] + if firstTupleElementType.isBranchTracingEnumIn(vjp: vjp) { + btePreds.append(firstTupleElementType) + } + } + return btePreds +} + +func iterateOverBranchTracingEnumPreds( + bteToPredsDict: inout BTEToPredsDict, + currentBTEType: SIL.`Type`, + vjp: Function +) { + let currentBTEPreds: [SIL.`Type`] = getBranchTracingEnumPreds(bteType: currentBTEType, vjp: vjp) + bteToPredsDict[currentBTEType] = currentBTEPreds + for currentBTEPred in currentBTEPreds { + if !bteToPredsDict.keys.contains(currentBTEPred) { + iterateOverBranchTracingEnumPreds( + bteToPredsDict: &bteToPredsDict, currentBTEType: currentBTEPred, vjp: vjp) + } + } +} + +func getBranchTracingEnumQueue(topBTEType: SIL.`Type`, vjp: Function) -> [SIL.`Type`] { + var bteToPredsDict = BTEToPredsDict() + iterateOverBranchTracingEnumPreds( + bteToPredsDict: &bteToPredsDict, + currentBTEType: topBTEType, + vjp: vjp) + var bteQueue = [SIL.`Type`]() + let totalEnums = bteToPredsDict.count + + for i in 0.. SIL.`Type` { + var silType = ty + if silType.rawType.hasArchetype { + silType = silType.mapTypeOutOfContext() + } + let remappedCanType = silType.rawType.getReducedType( + sig: function.loweredFunctionType.SILFunctionType_getSubstGenericSignature().genericSignature) + let remappedSILType = Type.getPrimitiveType( + canType: remappedCanType, silValueCategory: silType.category) + if !function.genericSignature.isEmpty { + return function.mapTypeIntoContext(remappedSILType) + } + return remappedSILType +} + +func getBranchingTraceEnumLoweredType(ed: Decl, vjp: Function) -> SIL.`Type` { + ed.bridged.getAs(NominalTypeDecl.self).declaredInterfaceType.canonical.loweredType(in: vjp) +} + +func getSourceFileFor(derivative: Function) -> BridgedSourceFile { + if let sourceFileRawPtr = derivative.sourceFile.raw { + return BridgedSourceFile(raw: sourceFileRawPtr) + } + return derivative.bridged.getFilesForModule().withElements(ofType: BridgedFileUnit.self) { + for fileUnit in $0 { + if let sourceFileRawPtr = fileUnit.castToSourceFile().raw { + return BridgedSourceFile?(BridgedSourceFile(raw: sourceFileRawPtr)) + } + } + assert(false) + return nil + }! +} + +func cloneGenericParameters( + astContext: BridgedASTContext, declContext: BridgedDeclContext, canGenericSig: CanGenericSignature +) -> BridgedGenericParamList { + var params = [BridgedGenericTypeParamDecl]() + for type in canGenericSig.genericSignature.genericParameters { + assert(type.isGenericTypeParameter) + params.append( + BridgedGenericTypeParamDecl.createImplicit( + declContext: declContext, + name: type.GenericTypeParam_getName(), + depth: type.GenericTypeParam_getDepth(), + index: type.GenericTypeParam_getIndex(), + paramKind: type.GenericTypeParam_getParamKind())) + } + return params.withBridgedArrayRef { + BridgedGenericParamList.createParsed( + astContext, leftAngleLoc: swift.SourceLoc(raw: nil), parameters: $0, + genericWhereClause: BridgedNullableTrailingWhereClause(raw: nil), + rightAngleLoc: swift.SourceLoc(raw: nil)) + } +} + +func autodiffSpecializeBranchTracingEnum( + bteType: SIL.`Type`, topVJP: Function, + bteCaseToClosureListDict: BTECaseToClosureListDict, + specBTEDict: [SIL.`Type`: SIL.`Type`], + context: FunctionPassContext +) -> SIL.`Type` { + assert(specBTEDict[bteType] == nil) + + let oldED = bteType.nominal as! EnumDecl + let declContext = oldED.asDecl.getDeclContext() + let astContext = declContext.astContext + + var newEDNameStr: String = oldED.name.string + "_spec" + var newPLs = [BridgedParameterList]() + + for enumCase in bteType.getEnumCases(in: topVJP)! { + let oldPayloadTupleType: Type = enumCase.payload! + let oldEED: EnumElementDecl = enumCase.enumElementDecl + + let oldPL: BridgedParameterList = oldEED.parameterList + assert(oldPL.size == 1) + let oldPD: BridgedParamDecl = oldPL.get(0) + + let closureInfosMultiBB: [ClosureInfoMultiBB] = bteCaseToClosureListDict[enumCase.index] ?? [] + + var newECDNameSuffix: String = "" + var newPayloadTupleElementTypes = [AST.`Type`]() + var labels = [swift.Identifier]() + + for idxInPayloadTuple in 0.. SpecBTEDict { + let bteQueue: [SIL.`Type`] = getBranchTracingEnumQueue(topBTEType: topBTE, vjp: topVJP) + + var specBTEDict = [SIL.`Type`: SIL.`Type`]() + for bteType in bteQueue { + let ed = bteType.nominal as! EnumDecl + let silType = remapType( + ty: getBranchingTraceEnumLoweredType(ed: ed, vjp: topVJP), function: topVJP) + + var bteCaseToClosureListDict = BTECaseToClosureListDict() + for closureInfoMultiBB in closureInfosMultiBB { + if closureInfoMultiBB.enumTypeAndCase.enumType != bteType { + continue + } + if bteCaseToClosureListDict[closureInfoMultiBB.enumTypeAndCase.caseIdx] == nil { + bteCaseToClosureListDict[closureInfoMultiBB.enumTypeAndCase.caseIdx] = [] + } + bteCaseToClosureListDict[closureInfoMultiBB.enumTypeAndCase.caseIdx]!.append( + closureInfoMultiBB) + } + + specBTEDict[silType] = autodiffSpecializeBranchTracingEnum( + bteType: silType, topVJP: topVJP, bteCaseToClosureListDict: bteCaseToClosureListDict, + specBTEDict: specBTEDict, context: context) + } + + return specBTEDict +} + let autodiffClosureSpecialization = FunctionPass(name: "autodiff-closure-specialization") { (function: Function, context: FunctionPassContext) in @@ -175,6 +521,90 @@ let autodiffClosureSpecialization = FunctionPass(name: "autodiff-closure-special private let specializationLevelLimit = 2 +private func getPartialApplyOfPullbackInExitVJPBB(vjp: Function) -> PartialApplyInst? { + log("getPartialApplyOfPullbackInExitVJPBB: running for VJP \(vjp.name)") + guard let exitBB = vjp.blocks.filter({ $0.terminator as? ReturnInst != nil }).getExactlyOneOrNil() + else { + log("getPartialApplyOfPullbackInExitVJPBB: exit BB not found, aborting") + return nil + } + + let ri = exitBB.terminator as! ReturnInst + guard let retValDefiningInstr = ri.returnedValue.definingInstruction else { + log( + "getPartialApplyOfPullbackInExitVJPBB: return value is not defined by an instruction, aborting" + ) + return nil + } + + func handleConvertFunctionOrPartialApply(inst: Instruction) -> PartialApplyInst? { + if let pai = inst as? PartialApplyInst { + log("getPartialApplyOfPullbackInExitVJPBB: success") + return pai + } + if let cfi = inst as? ConvertFunctionInst { + if let pai = cfi.fromFunction as? PartialApplyInst { + log("getPartialApplyOfPullbackInExitVJPBB: success") + return pai + } + log( + "getPartialApplyOfPullbackInExitVJPBB: fromFunction operand of convert_function instruction is not defined by partial_apply instruction, aborting" + ) + return nil + } + log("getPartialApplyOfPullbackInExitVJPBB: unexpected instruction type, aborting") + return nil + } + + if let ti = retValDefiningInstr as? TupleInst { + log("getPartialApplyOfPullbackInExitVJPBB: return value is defined by tuple instruction") + if ti.operands.count < 2 { + log( + "getPartialApplyOfPullbackInExitVJPBB: tuple instruction has \(ti.operands.count) operands, but at least 2 expected, aborting" + ) + return nil + } + guard let lastTupleElemDefiningInst = ti.operands.last!.value.definingInstruction else { + log( + "getPartialApplyOfPullbackInExitVJPBB: last tuple element is not defined by an instruction, aborting" + ) + return nil + } + return handleConvertFunctionOrPartialApply(inst: lastTupleElemDefiningInst) + } + + return handleConvertFunctionOrPartialApply(inst: retValDefiningInstr) +} + +private func getPullbackClosureInfoMultiBB(in vjp: Function, _ context: FunctionPassContext) + -> PullbackClosureInfo +{ + let paiOfPbInExitVjpBB = getPartialApplyOfPullbackInExitVJPBB(vjp: vjp)! + var pullbackClosureInfo = PullbackClosureInfo(paiOfPullback: paiOfPbInExitVjpBB) + var subsetThunkArr = [SingleValueInstruction]() + + for inst in vjp.instructions { + if inst == paiOfPbInExitVjpBB { + continue + } + if inst.asSupportedClosure == nil { + continue + } + + let rootClosure = inst.asSupportedClosure! + if subsetThunkArr.contains(rootClosure) { + continue + } + + let closureInfoArr = handleNonAppliesMultiBB(for: rootClosure, context) + pullbackClosureInfo.closureInfosMultiBB.append(contentsOf: closureInfoArr) + subsetThunkArr.append( + contentsOf: closureInfoArr.filter { $0.subsetThunk != nil }.map { $0.subsetThunk! }) + } + + return pullbackClosureInfo +} + private func getPullbackClosureInfo(in caller: Function, _ context: FunctionPassContext) -> PullbackClosureInfo? { @@ -374,6 +804,162 @@ private func updatePullbackClosureInfo( intermediateClosureArgDescriptorData: intermediateClosureArgDescriptorData, context) } +typealias BTEPayloadArgOfPbBBWithBTETypeAndCase = (arg: Argument, enumTypeAndCase: EnumTypeAndCase) + +// If the pullback's basic block has an argument which is a payload tuple of the +// branch tracing enum corresponding to the given VJP, return this argument and any valid combination +// of a branch tracing enum type and its case index having the same payload tuple type as the argument. +// The function assumes that no more than one such argument is present. +private func getBTEPayloadArgOfPbBBWithBTETypeAndCase(_ bb: BasicBlock, vjp: Function) + -> BTEPayloadArgOfPbBBWithBTETypeAndCase? +{ + log( + "getBTEPayloadArgOfPbBBWithBTETypeAndCase: basic block \(bb.shortDescription) in pullback \(bb.parentFunction.name)" + ) + guard let predBB = bb.predecessors.first else { + log("getBTEPayloadArgOfPbBBWithBTETypeAndCase: the bb has no predecessors, aborting") + return nil + } + + log("getBTEPayloadArgOfPbBBWithBTETypeAndCase: start iterating over bb args") + for arg in bb.arguments { + log("getBTEPayloadArgOfPbBBWithBTETypeAndCase: \(arg)") + if !arg.type.isTuple { + log("getBTEPayloadArgOfPbBBWithBTETypeAndCase: arg is not a tuple, skipping") + continue + } + + if let bi = predBB.terminator as? BranchInst { + log("getBTEPayloadArgOfPbBBWithBTETypeAndCase: terminator of pred bb is branch instruction") + guard let uedi = bi.operands[arg.index].value.definingInstruction as? UncheckedEnumDataInst + else { + log( + "getBTEPayloadArgOfPbBBWithBTETypeAndCase: operand corresponding to the argument is not defined by unchecked_enum_data instruction" + ) + continue + } + let enumType = uedi.`enum`.type + if !enumType.isBranchTracingEnumIn(vjp: vjp) { + log( + "getBTEPayloadArgOfPbBBWithBTETypeAndCase: enum type \(enumType) is not a branch tracing enum in VJP \(vjp.name)" + ) + continue + } + + log("getBTEPayloadArgOfPbBBWithBTETypeAndCase: success") + return BTEPayloadArgOfPbBBWithBTETypeAndCase( + arg: arg, + enumTypeAndCase: ( + enumType: enumType, + caseIdx: uedi.caseIndex + ) + ) + } + + if let sei = predBB.terminator as? SwitchEnumInst { + log( + "getBTEPayloadArgOfPbBBWithBTETypeAndCase: terminator of pred bb is switch_enum instruction" + ) + let enumType = sei.enumOp.type + if !enumType.isBranchTracingEnumIn(vjp: vjp) { + log( + "getBTEPayloadArgOfPbBBWithBTETypeAndCase: enum type \(enumType) is not a branch tracing enum in VJP \(vjp.name)" + ) + continue + } + + log("getBTEPayloadArgOfPbBBWithBTETypeAndCase: success") + return BTEPayloadArgOfPbBBWithBTETypeAndCase( + arg: arg, + enumTypeAndCase: ( + enumType: enumType, + caseIdx: sei.getUniqueCase(forSuccessor: bb)! + ) + ) + } + } + + log( + "getBTEPayloadArgOfPbBBWithBTETypeAndCase: finish iterating over bb args; branch tracing enum arg not found" + ) + return nil +} + +extension PartialApplyInst { + func isSubsetThunk() -> Bool { + if self.argumentOperands.singleElement == nil { + return false + } + guard let function = self.referencedFunction else { + return false + } + return function.bridged.isAutodiffSubsetParametersThunk() + } +} + +private func handleNonAppliesMultiBB( + for rootClosure: SingleValueInstruction, + _ context: FunctionPassContext +) + -> [ClosureInfoMultiBB] +{ + log("handleNonAppliesMultiBB: running for \(rootClosure)") + let vjp = rootClosure.parentFunction + var closureInfoArr = [ClosureInfoMultiBB]() + + var closure = rootClosure + var subsetThunk = PartialApplyInst?(nil) + if rootClosure.uses.singleElement != nil { + if let pai = closure.uses.singleElement!.instruction as? PartialApplyInst { + if pai.isSubsetThunk() { + log("handleNonAppliesMultiBB: found subset thunk \(pai)") + subsetThunk = pai + closure = pai + } + } + } + + for use in closure.uses { + guard let ti = use.instruction as? TupleInst else { + log("handleNonAppliesMultiBB: unexpected use of closure, aborting: \(use)") + return [] + } + for tiUse in ti.uses { + guard let ei = tiUse.instruction as? EnumInst else { + log("handleNonAppliesMultiBB: unexpected use of payload tuple, aborting: \(tiUse)") + return [] + } + if !ei.type.isBranchTracingEnumIn(vjp: vjp) { + log( + "handleNonAppliesMultiBB: enum type \(ei.type) is not a branch tracing enum in VJP \(vjp.name), aborting" + ) + return [] + } + var capturedArgs = [Value]() + if let pai = rootClosure as? PartialApplyInst { + capturedArgs = pai.argumentOperands.map { $0.value } + } + log( + "handleNonAppliesMultiBB: creating closure info with enum type \(ei.type), case index \(ei.caseIndex), index in payload tuple \(use.index) and payload tuple \(ti)" + ) + let enumTypeAndCase = (enumType: ei.type, caseIdx: ei.caseIndex) + closureInfoArr.append( + ClosureInfoMultiBB( + closure: rootClosure, + capturedArgs: capturedArgs, + subsetThunk: subsetThunk, + payloadTuple: ti, + idxInPayload: use.index, + enumTypeAndCase: enumTypeAndCase + )) + } + } + log( + "handleNonAppliesMultiBB: created \(closureInfoArr.count) closure info entries for \(rootClosure)" + ) + return closureInfoArr +} + /// Handles all non-apply direct and transitive uses of `rootClosure`. /// /// Returns: @@ -1391,6 +1977,7 @@ private struct ClosureArgDescriptor { private struct PullbackClosureInfo { let paiOfPullback: PartialApplyInst var closureArgDescriptors: [ClosureArgDescriptor] = [] + var closureInfosMultiBB: [ClosureInfoMultiBB] = [] init(paiOfPullback: PartialApplyInst) { self.paiOfPullback = paiOfPullback @@ -1474,3 +2061,131 @@ let rewrittenCallerBodyTest = FunctionTest("autodiff_closure_specialize_rewritte print("Rewritten caller body for: \(function.name):") print("\(function)\n") } + +let getPullbackClosureInfoMultiBBTest = FunctionTest( + "autodiff_closure_specialize_get_pullback_closure_info_multi_bb" +) { + function, arguments, context in + let pullbackClosureInfo = getPullbackClosureInfoMultiBB(in: function, context) + print("Run getPullbackClosureInfoMultiBB for VJP \(function.name): pullbackClosureInfo = (") + print(" pullbackFn = \(pullbackClosureInfo.pullbackFn.name)") + print(" closureInfosMultiBB = [") + for closureInfoMultiBB in pullbackClosureInfo.closureInfosMultiBB { + print(" ClosureInfoMultiBB(") + print(" closure: \(closureInfoMultiBB.closure)") + print(" capturedArgs: [") + for capturedArg in closureInfoMultiBB.capturedArgs { + print(" \(capturedArg)") + } + print(" ]") + let subsetThunkStr = + (closureInfoMultiBB.subsetThunk == nil ? "nil" : "\(closureInfoMultiBB.subsetThunk!)") + print(" subsetThunk: \(subsetThunkStr)") + print(" payloadTuple: \(closureInfoMultiBB.payloadTuple)") + print(" idxInPayload: \(closureInfoMultiBB.idxInPayload)") + print(" enumTypeAndCase: \(closureInfoMultiBB.enumTypeAndCase)") + print(" )") + } + print(" ]\n)\n") +} + +func getSpecBTEDict(vjp: Function, context: FunctionPassContext) -> SpecBTEDict { + let pullbackClosureInfo = getPullbackClosureInfoMultiBB(in: vjp, context) + let pb = pullbackClosureInfo.pullbackFn + let enumTypeOfEntryBBArg = pb.entryBlock.getBranchTracingEnumArg(vjp: vjp)!.type + let enumDict = autodiffSpecializeBranchTracingEnums( + topVJP: vjp, topBTE: enumTypeOfEntryBBArg, + closureInfosMultiBB: pullbackClosureInfo.closureInfosMultiBB, context: context) + return enumDict +} + +func specializeBranchTracingEnumBBArgInVJP(arg: Argument, specBTEDict: SpecBTEDict) -> Argument { + let oldOwnership = arg.bridged.getOwnership() + let bb = arg.parentBlock + let index = arg.index + assert(specBTEDict[arg.type] != nil) + let newType = specBTEDict[arg.type]! + let newArg = bb.bridged.insertPhiArgument(index, newType.bridged, oldOwnership).argument + return newArg +} + +extension SIL.`Type`: Comparable { + public static func < (lhs: SIL.`Type`, rhs: SIL.`Type`) -> Bool { + return "\(lhs)" < "\(rhs)" + } +} + +let specializeBranchTracingEnums = FunctionTest("autodiff_specialize_branch_tracing_enums") { + function, arguments, context in + let enumDict = getSpecBTEDict(vjp: function, context: context) + print( + "Specialized branch tracing enum dict for VJP \(function.name) contains \(enumDict.count) elements:" + ) + + var keys = [SIL.`Type`](enumDict.keys) + keys.sort() + for (idx, key) in keys.enumerated() { + print("non-specialized BTE \(idx): \(key.nominal!.description)") + print("specialized BTE \(idx): \(enumDict[key]!.nominal!.description)") + } + print("") +} + +let specializeBTEArgInVjpBB = FunctionTest("autodiff_specialize_bte_arg_in_vjp_bb") { + function, arguments, context in + let enumDict = getSpecBTEDict(vjp: function, context: context) + print("Specialized BTE arguments of basic blocks in VJP \(function.name):") + for bb in function.blocks { + guard let arg = bb.getBranchTracingEnumArg(vjp: function) else { + continue + } + let newArg = specializeBranchTracingEnumBBArgInVJP(arg: arg, specBTEDict: enumDict) + print("\(newArg)") + bb.eraseArgument(at: newArg.index, context) + } + print("") +} + +func specializePayloadTupleBBArgInPullback( + arg: Argument, + enumTypeAndCase: EnumTypeAndCase +) -> Argument { + let bb = arg.parentBlock + let newEnumType = enumTypeAndCase.enumType + + var newPayloadTupleTy = SIL.`Type`?(nil) + for enumCase in newEnumType.getEnumCases(in: arg.parentFunction)! { + if enumCase.index == enumTypeAndCase.caseIdx { + newPayloadTupleTy = enumCase.payload! + break + } + } + assert(newPayloadTupleTy != nil) + + let oldOwnership = arg.bridged.getOwnership() + return bb.bridged.insertPhiArgument(arg.index, newPayloadTupleTy!.bridged, oldOwnership).argument +} + +let specializePayloadArgInPullbackBB = FunctionTest("autodiff_specialize_payload_arg_in_pb_bb") { + function, arguments, context in + let pullbackClosureInfo = getPullbackClosureInfoMultiBB(in: function, context) + let pb = pullbackClosureInfo.pullbackFn + let enumDict = getSpecBTEDict(vjp: function, context: context) + + print("Specialized BTE payload arguments of basic blocks in pullback \(pb.name):") + for bb in pb.blocks { + guard + let (arg, enumTypeAndCase) = getBTEPayloadArgOfPbBBWithBTETypeAndCase(bb, vjp: function) + else { + continue + } + + let enumType = enumDict[enumTypeAndCase.enumType]! + let newArg = specializePayloadTupleBBArgInPullback( + arg: arg, + enumTypeAndCase: (enumType: enumType, caseIdx: enumTypeAndCase.caseIdx)) + print("\(newArg)") + bb.eraseArgument(at: newArg.index, context) + } + print("") +} diff --git a/SwiftCompilerSources/Sources/Optimizer/Utilities/FunctionTest.swift b/SwiftCompilerSources/Sources/Optimizer/Utilities/FunctionTest.swift index 984d9609e4eda..e15e69cc2e057 100644 --- a/SwiftCompilerSources/Sources/Optimizer/Utilities/FunctionTest.swift +++ b/SwiftCompilerSources/Sources/Optimizer/Utilities/FunctionTest.swift @@ -44,6 +44,7 @@ public func registerOptimizerTests() { addressOwnershipLiveRangeTest, argumentConventionsTest, getPullbackClosureInfoTest, + getPullbackClosureInfoMultiBBTest, interiorLivenessTest, lifetimeDependenceRootTest, lifetimeDependenceScopeTest, @@ -53,6 +54,9 @@ public func registerOptimizerTests() { localVariableReachingAssignmentsTest, rangeOverlapsPathTest, rewrittenCallerBodyTest, + specializeBranchTracingEnums, + specializeBTEArgInVjpBB, + specializePayloadArgInPullbackBB, specializedFunctionSignatureAndBodyTest, variableIntroducerTest ) diff --git a/include/swift/SIL/SILBridging.h b/include/swift/SIL/SILBridging.h index 77a2a7613d13e..d0e0e64ebcab2 100644 --- a/include/swift/SIL/SILBridging.h +++ b/include/swift/SIL/SILBridging.h @@ -584,6 +584,7 @@ struct BridgedFunction { bool isTrapNoReturn() const; bool isConvertPointerToPointerArgument() const; bool isAutodiffVJP() const; + bool isAutodiffSubsetParametersThunk() const; SwiftInt specializationLevel() const; SWIFT_IMPORT_UNSAFE BridgedSubstitutionMap getMethodSubstitutions(BridgedSubstitutionMap contextSubs, BridgedCanType selfType) const; diff --git a/lib/SILOptimizer/Utils/OptimizerBridging.cpp b/lib/SILOptimizer/Utils/OptimizerBridging.cpp index 5b0cdef492b6a..15588b27a6853 100644 --- a/lib/SILOptimizer/Utils/OptimizerBridging.cpp +++ b/lib/SILOptimizer/Utils/OptimizerBridging.cpp @@ -13,6 +13,7 @@ #include "swift/SILOptimizer/OptimizerBridging.h" #include "../../IRGen/IRGenModule.h" #include "swift/AST/SemanticAttrs.h" +#include "swift/Demangling/ManglingMacros.h" #include "swift/SIL/DynamicCasts.h" #include "swift/SIL/OSSACompleteLifetime.h" #include "swift/SIL/SILCloner.h" @@ -527,6 +528,73 @@ bool BridgedFunction::isAutodiffVJP() const { getFunction(), swift::AutoDiffFunctionComponent::VJP); } +bool BridgedFunction::isAutodiffSubsetParametersThunk() const { + Demangle::Context Ctx; + if (auto *root = Ctx.demangleSymbolAsNode(getFunction()->getName())) { + // root node has Global kind, the AutoDiffSubsetParametersThunk node (if + // present) is direct child of root. + return root->findByKind(Demangle::Node::Kind::AutoDiffSubsetParametersThunk, + /*maxDepth=*/1) != nullptr; + } + return false; +} + +// See also ASTMangler::mangleAutoDiffGeneratedDeclaration. +bool BridgedType::isAutodiffBranchTracingEnumInVJP(BridgedFunction vjp) const { + assert(vjp.isAutodiffVJP()); + EnumDecl *ed = unbridged().getEnumOrBoundGenericEnum(); + if (ed == nullptr) + return false; + + llvm::StringRef edName = ed->getNameStr(); + if (!edName.starts_with("_AD__")) + return false; + if (!llvm::StringRef(edName.data() + 5, edName.size() - 5) + .starts_with(MANGLING_PREFIX_STR)) + return false; + + // At this point, we know that the type is indeed a branch tracing enum. + // Now we need to ensure that it is the enum related to the given VJP. + + std::size_t idx = edName.rfind("__Pred__"); + assert(idx != std::string::npos); + + // Before "__Pred__", we have "_bbX", where X is a number. + // The loop calculates the start position of X. + for (; idx != 0 && std::isdigit(edName[idx - 1]); --idx) + ; + + assert(std::isdigit(edName[idx])); + assert(!std::isdigit(edName[idx - 1])); + + // The branch tracing enum decl name has the following components: + // 1) "_AD__"; + // 2) MANGLING_PREFIX; + // 3) original function name; + // 4) "_bb"; + // 5) X at position idx (see above); + // 6) the rest of the enum decl name. + // Thus, "_AD__", MANGLING_PREFIX and "_bb" must have total length less than + // idx. + std::size_t manglingPrefixSize = std::strlen(MANGLING_PREFIX_STR); + assert(idx > 5 + manglingPrefixSize + 3); + assert(std::string_view(edName.data() + idx - 3, 3) == "_bb"); + assert(std::string_view(edName.data(), 5 + manglingPrefixSize) == "_AD__$s"); + + llvm::StringRef enumOrigFuncName = + std::string_view(edName.data() + 5 + manglingPrefixSize, + idx - (5 + manglingPrefixSize + 3)); + + Demangle::Context Ctx; + if (auto *root = Ctx.demangleSymbolAsNode(vjp.getFunction()->getName())) + if (auto *node = + root->findByKind(Demangle::Node::Kind::Function, /*maxDepth=*/3)) + if (mangleNode(node).result() == enumOrigFuncName) + return true; + + return false; +} + SwiftInt BridgedFunction::specializationLevel() const { return swift::getSpecializationLevel(getFunction()); } diff --git a/test/AutoDiff/SILOptimizer/closure_specialization/multi_bb_bte.sil b/test/AutoDiff/SILOptimizer/closure_specialization/multi_bb_bte.sil index ea361e96a6832..698545bf3790f 100644 --- a/test/AutoDiff/SILOptimizer/closure_specialization/multi_bb_bte.sil +++ b/test/AutoDiff/SILOptimizer/closure_specialization/multi_bb_bte.sil @@ -45,6 +45,35 @@ bb0(%0 : $Float, %1 : $_AD__$s4test5mul42yS2fSgF_bb2__Pred__src_0_wrt_0, %2 : $@ // reverse-mode derivative of mul42(_:) sil hidden @$s4test5mul42yS2fSgFTJrSpSr : $@convention(thin) (Optional) -> (Float, @owned @callee_guaranteed (Float) -> Optional.TangentVector) { bb0(%0 : $Optional): + specify_test "autodiff_closure_specialize_get_pullback_closure_info_multi_bb" + // TRUNNER-LABEL: Run getPullbackClosureInfoMultiBB for VJP $s4test5mul42yS2fSgFTJrSpSr: pullbackClosureInfo = ( + // TRUNNER-NEXT: pullbackFn = $s4test5mul42yS2fSgFTJpSpSr + // TRUNNER-NEXT: closureInfosMultiBB = [ + // TRUNNER-NEXT: ]{{$}} + // TRUNNER-NEXT: ){{$}} + // TRUNNER-EMPTY: + + //=========== Test specialized branch tracing enums ===========// + specify_test "autodiff_specialize_branch_tracing_enums" + // TRUNNER-LABEL: Specialized branch tracing enum dict for VJP $s4test5mul42yS2fSgFTJrSpSr contains 1 elements: + // TRUNNER-NEXT: non-specialized BTE 0: enum _AD__$s4test5mul42yS2fSgF_bb2__Pred__src_0_wrt_0 { + // TRUNNER-NEXT: case bb0(()) + // TRUNNER-NEXT: } + // TRUNNER-NEXT: specialized BTE 0: enum _AD__$s4test5mul42yS2fSgF_bb2__Pred__src_0_wrt_0_spec { + // TRUNNER-NEXT: case bb0(()) + // TRUNNER-NEXT: } + // TRUNNER-EMPTY: + + //=========== Test specialized branch tracing enum args in VJP ===========// + specify_test "autodiff_specialize_bte_arg_in_vjp_bb" + // TRUNNER-LABEL: Specialized BTE arguments of basic blocks in VJP $s4test5mul42yS2fSgFTJrSpSr: + // TRUNNER-EMPTY: + + //=========== Test specialized branch tracing enum payload args in pullback ===========// + specify_test "autodiff_specialize_payload_arg_in_pb_bb" + // TRUNNER-LABEL: Specialized BTE payload arguments of basic blocks in pullback $s4test5mul42yS2fSgFTJpSpSr: + // TRUNNER-EMPTY: + //=========== Test callsite and closure gathering logic ===========// specify_test "autodiff_closure_specialize_get_pullback_closure_info" // TRUNNER-LABEL: Specializing closures in function: $s4test5mul42yS2fSgFTJrSpSr @@ -299,6 +328,114 @@ bb6(%122 : $Builtin.FPIEEE32): // reverse-mode derivative of Class.method() sil hidden @$s4test5ClassV6methodSfyFTJrSpSr : $@convention(method) (Class) -> (Float, @owned @callee_guaranteed (Float) -> Class.TangentVector) { bb0(%0 : $Class): + specify_test "autodiff_closure_specialize_get_pullback_closure_info_multi_bb" + // TRUNNER-LABEL: Run getPullbackClosureInfoMultiBB for VJP $s4test5ClassV6methodSfyFTJrSpSr: pullbackClosureInfo = ( + // TRUNNER-NEXT: pullbackFn = $s4test5ClassV6methodSfyFTJpSpSr + // TRUNNER-NEXT: closureInfosMultiBB = [ + // TRUNNER-NEXT: ClosureInfoMultiBB( + // TRUNNER-NEXT: closure: %[[#]] = partial_apply [callee_guaranteed] %[[#]](%[[#]], %[[#]]) : $@convention(thin) (Float, Float, Float) -> (Float, Float) + // TRUNNER-NEXT: capturedArgs: [ + // TRUNNER-NEXT: %[[#]] = struct_extract %0 : $Class, #Class.stored + // TRUNNER-NEXT: %[[#]] = struct $Float (%[[#]] : $Builtin.FPIEEE32) + // TRUNNER-NEXT: ]{{$}} + // TRUNNER-NEXT: subsetThunk: %[[#]] = partial_apply [callee_guaranteed] %[[#]](%[[#]]) : $@convention(thin) (Float, @guaranteed @callee_guaranteed (Float) -> (Float, Float)) -> Float + // TRUNNER-NEXT: payloadTuple: %[[#]] = tuple (%[[#]] : $@callee_guaranteed (Float) -> Float, %[[#]] : $@callee_guaranteed (Class.TangentVector) -> (Float, Optional.TangentVector)) + // TRUNNER-NEXT: idxInPayload: 0{{$}} + // TRUNNER-NEXT: enumTypeAndCase: (enumType: $_AD__$s4test5ClassV6methodSfyF_bb2__Pred__src_0_wrt_0, caseIdx: 0) + // TRUNNER-NEXT: ){{$}} + // TRUNNER-NEXT: ClosureInfoMultiBB( + // TRUNNER-NEXT: closure: %[[#]] = partial_apply [callee_guaranteed] %[[#]](%[[#]], %[[#]]) : $@convention(thin) (Float, Float, Float) -> (Float, Float) + // TRUNNER-NEXT: capturedArgs: [ + // TRUNNER-NEXT: %[[#]] = struct_extract %0 : $Class, #Class.stored + // TRUNNER-NEXT: %[[#]] = struct $Float (%[[#]] : $Builtin.FPIEEE32) + // TRUNNER-NEXT: ]{{$}} + // TRUNNER-NEXT: subsetThunk: %[[#]] = partial_apply [callee_guaranteed] %[[#]](%[[#]]) : $@convention(thin) (Float, @guaranteed @callee_guaranteed (Float) -> (Float, Float)) -> Float + // TRUNNER-NEXT: payloadTuple: %[[#]] = tuple (%[[#]] : $@callee_guaranteed (Float) -> Float, %[[#]] : $@callee_guaranteed (Class.TangentVector) -> (Float, Optional.TangentVector)) + // TRUNNER-NEXT: idxInPayload: 0{{$}} + // TRUNNER-NEXT: enumTypeAndCase: (enumType: $_AD__$s4test5ClassV6methodSfyF_bb1__Pred__src_0_wrt_0, caseIdx: 0) + // TRUNNER-NEXT: ){{$}} + // TRUNNER-NEXT: ClosureInfoMultiBB( + // TRUNNER-NEXT: closure: %[[#]] = thin_to_thick_function %[[#]] : $@convention(thin) (Class.TangentVector) -> (Float, Optional.TangentVector) to $@callee_guaranteed (Class.TangentVector) -> (Float, Optional.TangentVector) + // TRUNNER-NEXT: capturedArgs: [ + // TRUNNER-NEXT: ]{{$}} + // TRUNNER-NEXT: subsetThunk: nil + // TRUNNER-NEXT: payloadTuple: %[[#]] = tuple (%[[#]] : $@callee_guaranteed (Float) -> Float, %[[#]] : $@callee_guaranteed (Class.TangentVector) -> (Float, Optional.TangentVector)) + // TRUNNER-NEXT: idxInPayload: 1{{$}} + // TRUNNER-NEXT: enumTypeAndCase: (enumType: $_AD__$s4test5ClassV6methodSfyF_bb2__Pred__src_0_wrt_0, caseIdx: 0) + // TRUNNER-NEXT: ){{$}} + // TRUNNER-NEXT: ClosureInfoMultiBB( + // TRUNNER-NEXT: closure: %[[#]] = thin_to_thick_function %[[#]] : $@convention(thin) (Class.TangentVector) -> (Float, Optional.TangentVector) to $@callee_guaranteed (Class.TangentVector) -> (Float, Optional.TangentVector) + // TRUNNER-NEXT: capturedArgs: [ + // TRUNNER-NEXT: ]{{$}} + // TRUNNER-NEXT: subsetThunk: nil + // TRUNNER-NEXT: payloadTuple: %[[#]] = tuple (%[[#]] : $@callee_guaranteed (Float) -> Float, %[[#]] : $@callee_guaranteed (Class.TangentVector) -> (Float, Optional.TangentVector)) + // TRUNNER-NEXT: idxInPayload: 1{{$}} + // TRUNNER-NEXT: enumTypeAndCase: (enumType: $_AD__$s4test5ClassV6methodSfyF_bb1__Pred__src_0_wrt_0, caseIdx: 0) + // TRUNNER-NEXT: ){{$}} + // TRUNNER-NEXT: ClosureInfoMultiBB( + // TRUNNER-NEXT: closure: %[[#]] = partial_apply [callee_guaranteed] %[[#]](%[[#]], %[[#]]) : $@convention(thin) (Float, Float, Float) -> (Float, Float) + // TRUNNER-NEXT: capturedArgs: [ + // TRUNNER-NEXT: %[[#]] = struct $Float (%[[#]] : $Builtin.FPIEEE32) + // TRUNNER-NEXT: %[[#]] = argument of bb1 : $Float + // TRUNNER-NEXT: ]{{$}} + // TRUNNER-NEXT: subsetThunk: nil + // TRUNNER-NEXT: payloadTuple: %[[#]] = tuple $(predecessor: _AD__$s4test5ClassV6methodSfyF_bb1__Pred__src_0_wrt_0, @callee_guaranteed (Float) -> (Float, Float)) (%[[#]], %[[#]]) + // TRUNNER-NEXT: idxInPayload: 1{{$}} + // TRUNNER-NEXT: enumTypeAndCase: (enumType: $_AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0, caseIdx: 1) + // TRUNNER-NEXT: ){{$}} + // TRUNNER-NEXT: ClosureInfoMultiBB( + // TRUNNER-NEXT: closure: %[[#]] = partial_apply [callee_guaranteed] %[[#]](%[[#]], %[[#]]) : $@convention(thin) (Float, Float, Float) -> (Float, Float) + // TRUNNER-NEXT: capturedArgs: [ + // TRUNNER-NEXT: %[[#]] = struct $Float (%[[#]] : $Builtin.FPIEEE32) + // TRUNNER-NEXT: %[[#]] = struct $Float (%[[#]] : $Builtin.FPIEEE32) + // TRUNNER-NEXT: ]{{$}} + // TRUNNER-NEXT: subsetThunk: %[[#]] = partial_apply [callee_guaranteed] %[[#]](%[[#]]) : $@convention(thin) (Float, @guaranteed @callee_guaranteed (Float) -> (Float, Float)) -> Float + // TRUNNER-NEXT: payloadTuple: %[[#]] = tuple $(predecessor: _AD__$s4test5ClassV6methodSfyF_bb2__Pred__src_0_wrt_0, @callee_guaranteed (Float) -> Float) (%[[#]], %[[#]]) + // TRUNNER-NEXT: idxInPayload: 1{{$}} + // TRUNNER-NEXT: enumTypeAndCase: (enumType: $_AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0, caseIdx: 0) + // TRUNNER-NEXT: ){{$}} + // TRUNNER-NEXT: ]{{$}} + // TRUNNER-NEXT: ){{$}} + // TRUNNER-EMPTY: + + //=========== Test specialized branch tracing enums ===========// + specify_test "autodiff_specialize_branch_tracing_enums" + // TRUNNER-LABEL: Specialized branch tracing enum dict for VJP $s4test5ClassV6methodSfyFTJrSpSr contains 3 elements: + // TRUNNER-NEXT: non-specialized BTE 0: enum _AD__$s4test5ClassV6methodSfyF_bb1__Pred__src_0_wrt_0 { + // TRUNNER-NEXT: case bb0(((Float) -> Float, (Class.TangentVector) -> (Float, Optional.TangentVector))) + // TRUNNER-NEXT: } + // TRUNNER-NEXT: specialized BTE 0: enum _AD__$s4test5ClassV6methodSfyF_bb1__Pred__src_0_wrt_0_spec_bb0_0_1 { + // TRUNNER-NEXT: case bb0(((Float, Float), ())) + // TRUNNER-NEXT: } + // TRUNNER-NEXT: non-specialized BTE 1: enum _AD__$s4test5ClassV6methodSfyF_bb2__Pred__src_0_wrt_0 { + // TRUNNER-NEXT: case bb0(((Float) -> Float, (Class.TangentVector) -> (Float, Optional.TangentVector))) + // TRUNNER-NEXT: } + // TRUNNER-NEXT: specialized BTE 1: enum _AD__$s4test5ClassV6methodSfyF_bb2__Pred__src_0_wrt_0_spec_bb0_0_1 { + // TRUNNER-NEXT: case bb0(((Float, Float), ())) + // TRUNNER-NEXT: } + // TRUNNER-NEXT: non-specialized BTE 2: enum _AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0 { + // TRUNNER-NEXT: case bb2((predecessor: _AD__$s4test5ClassV6methodSfyF_bb2__Pred__src_0_wrt_0, (Float) -> Float)) + // TRUNNER-NEXT: case bb1((predecessor: _AD__$s4test5ClassV6methodSfyF_bb1__Pred__src_0_wrt_0, (Float) -> (Float, Float))) + // TRUNNER-NEXT: } + // TRUNNER-NEXT: specialized BTE 2: enum _AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0_spec_bb2_1_bb1_1 { + // TRUNNER-NEXT: case bb2((predecessor: _AD__$s4test5ClassV6methodSfyF_bb2__Pred__src_0_wrt_0_spec_bb0_0_1, (Float, Float))) + // TRUNNER-NEXT: case bb1((predecessor: _AD__$s4test5ClassV6methodSfyF_bb1__Pred__src_0_wrt_0_spec_bb0_0_1, (Float, Float))) + // TRUNNER-NEXT: } + // TRUNNER-EMPTY: + + //=========== Test specialized branch tracing enum args in VJP ===========// + specify_test "autodiff_specialize_bte_arg_in_vjp_bb" + // TRUNNER-LABEL: Specialized BTE arguments of basic blocks in VJP $s4test5ClassV6methodSfyFTJrSpSr: + // TRUNNER-NEXT: %[[#]] = argument of bb3 : $_AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0_spec_bb2_1_bb1_1 + // TRUNNER-EMPTY: + + //=========== Test specialized branch tracing enum payload args in pullback ===========// + specify_test "autodiff_specialize_payload_arg_in_pb_bb" + // TRUNNER-LABEL: Specialized BTE payload arguments of basic blocks in pullback $s4test5ClassV6methodSfyFTJpSpSr: + // TRUNNER-NEXT: %[[#]] = argument of bb1 : $(predecessor: _AD__$s4test5ClassV6methodSfyF_bb2__Pred__src_0_wrt_0_spec_bb0_0_1, (Float, Float)) + // TRUNNER-NEXT: %[[#]] = argument of bb2 : $(predecessor: _AD__$s4test5ClassV6methodSfyF_bb1__Pred__src_0_wrt_0_spec_bb0_0_1, (Float, Float)) + // TRUNNER-EMPTY: + //=========== Test callsite and closure gathering logic ===========// specify_test "autodiff_closure_specialize_get_pullback_closure_info" // TRUNNER-LABEL: Specializing closures in function: $s4test5ClassV6methodSfyFTJrSpSr @@ -568,6 +705,117 @@ bb3(%118 : $Builtin.FPIEEE32, %119 : $Builtin.FPIEEE32, %120 : $Builtin.FPIEEE32 sil hidden @$s4test14cond_tuple_varyS2fFTJrSpSr : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) { [global: ] bb0(%0 : $Float): + specify_test "autodiff_closure_specialize_get_pullback_closure_info_multi_bb" + // TRUNNER-LABEL: Run getPullbackClosureInfoMultiBB for VJP $s4test14cond_tuple_varyS2fFTJrSpSr: pullbackClosureInfo = ( + // TRUNNER-NEXT: pullbackFn = $s4test14cond_tuple_varyS2fFTJpSpSr + // TRUNNER-NEXT: closureInfosMultiBB = [ + // TRUNNER-NEXT: ClosureInfoMultiBB( + // TRUNNER-NEXT: closure: %[[#]] = thin_to_thick_function %[[#]] : $@convention(thin) (Float) -> (Float, Float) to $@callee_guaranteed (Float) -> (Float, Float) + // TRUNNER-NEXT: capturedArgs: [ + // TRUNNER-NEXT: ] + // TRUNNER-NEXT: subsetThunk: nil + // TRUNNER-NEXT: payloadTuple: %[[#]] = tuple (%[[#]] : $@callee_guaranteed (Float) -> (Float, Float), %[[#]] : $@callee_guaranteed (Float) -> (Float, Float)) + // TRUNNER-NEXT: idxInPayload: 0{{$}} + // TRUNNER-NEXT: enumTypeAndCase: (enumType: $_AD__$s4test14cond_tuple_varyS2fF_bb2__Pred__src_0_wrt_0, caseIdx: 0) + // TRUNNER-NEXT: ){{$}} + // TRUNNER-NEXT: ClosureInfoMultiBB( + // TRUNNER-NEXT: closure: %[[#]] = thin_to_thick_function %[[#]] : $@convention(thin) (Float) -> (Float, Float) to $@callee_guaranteed (Float) -> (Float, Float) + // TRUNNER-NEXT: capturedArgs: [ + // TRUNNER-NEXT: ] + // TRUNNER-NEXT: subsetThunk: nil + // TRUNNER-NEXT: payloadTuple: %[[#]] = tuple (%[[#]] : $@callee_guaranteed (Float) -> (Float, Float), %[[#]] : $@callee_guaranteed (Float) -> (Float, Float)) + // TRUNNER-NEXT: idxInPayload: 0{{$}} + // TRUNNER-NEXT: enumTypeAndCase: (enumType: $_AD__$s4test14cond_tuple_varyS2fF_bb1__Pred__src_0_wrt_0, caseIdx: 0) + // TRUNNER-NEXT: ){{$}} + // TRUNNER-NEXT: ClosureInfoMultiBB( + // TRUNNER-NEXT: closure: %[[#]] = thin_to_thick_function %[[#]] : $@convention(thin) (Float) -> (Float, Float) to $@callee_guaranteed (Float) -> (Float, Float) + // TRUNNER-NEXT: capturedArgs: [ + // TRUNNER-NEXT: ] + // TRUNNER-NEXT: subsetThunk: nil + // TRUNNER-NEXT: payloadTuple: %[[#]] = tuple (%[[#]] : $@callee_guaranteed (Float) -> (Float, Float), %[[#]] : $@callee_guaranteed (Float) -> (Float, Float)) + // TRUNNER-NEXT: idxInPayload: 1{{$}} + // TRUNNER-NEXT: enumTypeAndCase: (enumType: $_AD__$s4test14cond_tuple_varyS2fF_bb2__Pred__src_0_wrt_0, caseIdx: 0) + // TRUNNER-NEXT: ){{$}} + // TRUNNER-NEXT: ClosureInfoMultiBB( + // TRUNNER-NEXT: closure: %[[#]] = thin_to_thick_function %[[#]] : $@convention(thin) (Float) -> (Float, Float) to $@callee_guaranteed (Float) -> (Float, Float) + // TRUNNER-NEXT: capturedArgs: [ + // TRUNNER-NEXT: ] + // TRUNNER-NEXT: subsetThunk: nil + // TRUNNER-NEXT: payloadTuple: %[[#]] = tuple (%[[#]] : $@callee_guaranteed (Float) -> (Float, Float), %[[#]] : $@callee_guaranteed (Float) -> (Float, Float)) + // TRUNNER-NEXT: idxInPayload: 1{{$}} + // TRUNNER-NEXT: enumTypeAndCase: (enumType: $_AD__$s4test14cond_tuple_varyS2fF_bb1__Pred__src_0_wrt_0, caseIdx: 0) + // TRUNNER-NEXT: ){{$}} + // TRUNNER-NEXT: ClosureInfoMultiBB( + // TRUNNER-NEXT: closure: %[[#]] = thin_to_thick_function %[[#]] : $@convention(thin) (Float) -> (Float, Float) to $@callee_guaranteed (Float) -> (Float, Float) + // TRUNNER-NEXT: capturedArgs: [ + // TRUNNER-NEXT: ] + // TRUNNER-NEXT: subsetThunk: nil + // TRUNNER-NEXT: payloadTuple: %[[#]] = tuple $(predecessor: _AD__$s4test14cond_tuple_varyS2fF_bb1__Pred__src_0_wrt_0, @callee_guaranteed (Float) -> (Float, Float), @callee_guaranteed (Float) -> (Float, Float)) (%[[#]], %[[#]], %[[#]]) + // TRUNNER-NEXT: idxInPayload: 1{{$}} + // TRUNNER-NEXT: enumTypeAndCase: (enumType: $_AD__$s4test14cond_tuple_varyS2fF_bb3__Pred__src_0_wrt_0, caseIdx: 1) + // TRUNNER-NEXT: ){{$}} + // TRUNNER-NEXT: ClosureInfoMultiBB( + // TRUNNER-NEXT: closure: %[[#]] = thin_to_thick_function %[[#]] : $@convention(thin) (Float) -> (Float, Float) to $@callee_guaranteed (Float) -> (Float, Float) + // TRUNNER-NEXT: capturedArgs: [ + // TRUNNER-NEXT: ] + // TRUNNER-NEXT: subsetThunk: nil + // TRUNNER-NEXT: payloadTuple: %[[#]] = tuple $(predecessor: _AD__$s4test14cond_tuple_varyS2fF_bb1__Pred__src_0_wrt_0, @callee_guaranteed (Float) -> (Float, Float), @callee_guaranteed (Float) -> (Float, Float)) (%[[#]], %[[#]], %[[#]]) + // TRUNNER-NEXT: idxInPayload: 2{{$}} + // TRUNNER-NEXT: enumTypeAndCase: (enumType: $_AD__$s4test14cond_tuple_varyS2fF_bb3__Pred__src_0_wrt_0, caseIdx: 1) + // TRUNNER-NEXT: ){{$}} + // TRUNNER-NEXT: ClosureInfoMultiBB( + // TRUNNER-NEXT: closure: %[[#]] = partial_apply [callee_guaranteed] %[[#]](%0, %[[#]]) : $@convention(thin) (Float, Float, Float) -> (Float, Float) + // TRUNNER-NEXT: capturedArgs: [ + // TRUNNER-NEXT: %0 = argument of bb0 : $Float + // TRUNNER-NEXT: %[[#]] = struct $Float (%[[#]] : $Builtin.FPIEEE32) + // TRUNNER-NEXT: ] + // TRUNNER-NEXT: subsetThunk: %[[#]] = partial_apply [callee_guaranteed] %[[#]](%[[#]]) : $@convention(thin) (Float, @guaranteed @callee_guaranteed (Float) -> (Float, Float)) -> Float + // TRUNNER-NEXT: payloadTuple: %[[#]] = tuple $(predecessor: _AD__$s4test14cond_tuple_varyS2fF_bb2__Pred__src_0_wrt_0, @callee_guaranteed (Float) -> Float) (%[[#]], %[[#]]) + // TRUNNER-NEXT: idxInPayload: 1{{$}} + // TRUNNER-NEXT: enumTypeAndCase: (enumType: $_AD__$s4test14cond_tuple_varyS2fF_bb3__Pred__src_0_wrt_0, caseIdx: 0) + // TRUNNER-NEXT: ){{$}} + // TRUNNER-NEXT: ]{{$}} + // TRUNNER-NEXT: ){{$}} + // TRUNNER-EMPTY: + + //=========== Test specialized branch tracing enums ===========// + specify_test "autodiff_specialize_branch_tracing_enums" + // TRUNNER-LABEL: Specialized branch tracing enum dict for VJP $s4test14cond_tuple_varyS2fFTJrSpSr contains 3 elements: + // TRUNNER-NEXT: non-specialized BTE 0: enum _AD__$s4test14cond_tuple_varyS2fF_bb1__Pred__src_0_wrt_0 { + // TRUNNER-NEXT: case bb0(((Float) -> (Float, Float), (Float) -> (Float, Float))) + // TRUNNER-NEXT: } + // TRUNNER-NEXT: specialized BTE 0: enum _AD__$s4test14cond_tuple_varyS2fF_bb1__Pred__src_0_wrt_0_spec_bb0_0_1 { + // TRUNNER-NEXT: case bb0(((), ())) + // TRUNNER-NEXT: } + // TRUNNER-NEXT: non-specialized BTE 1: enum _AD__$s4test14cond_tuple_varyS2fF_bb2__Pred__src_0_wrt_0 { + // TRUNNER-NEXT: case bb0(((Float) -> (Float, Float), (Float) -> (Float, Float))) + // TRUNNER-NEXT: } + // TRUNNER-NEXT: specialized BTE 1: enum _AD__$s4test14cond_tuple_varyS2fF_bb2__Pred__src_0_wrt_0_spec_bb0_0_1 { + // TRUNNER-NEXT: case bb0(((), ())) + // TRUNNER-NEXT: } + // TRUNNER-NEXT: non-specialized BTE 2: enum _AD__$s4test14cond_tuple_varyS2fF_bb3__Pred__src_0_wrt_0 { + // TRUNNER-NEXT: case bb2((predecessor: _AD__$s4test14cond_tuple_varyS2fF_bb2__Pred__src_0_wrt_0, (Float) -> Float)) + // TRUNNER-NEXT: case bb1((predecessor: _AD__$s4test14cond_tuple_varyS2fF_bb1__Pred__src_0_wrt_0, (Float) -> (Float, Float), (Float) -> (Float, Float))) + // TRUNNER-NEXT: } + // TRUNNER-NEXT: specialized BTE 2: enum _AD__$s4test14cond_tuple_varyS2fF_bb3__Pred__src_0_wrt_0_spec_bb2_1_bb1_1_2 { + // TRUNNER-NEXT: case bb2((predecessor: _AD__$s4test14cond_tuple_varyS2fF_bb2__Pred__src_0_wrt_0_spec_bb0_0_1, (Float, Float))) + // TRUNNER-NEXT: case bb1((predecessor: _AD__$s4test14cond_tuple_varyS2fF_bb1__Pred__src_0_wrt_0_spec_bb0_0_1, (), ())) + // TRUNNER-NEXT: } + // TRUNNER-EMPTY: + + //=========== Test specialized branch tracing enum args in VJP ===========// + specify_test "autodiff_specialize_bte_arg_in_vjp_bb" + // TRUNNER-LABEL: Specialized BTE arguments of basic blocks in VJP $s4test14cond_tuple_varyS2fFTJrSpSr: + // TRUNNER-NEXT: %[[#]] = argument of bb3 : $_AD__$s4test14cond_tuple_varyS2fF_bb3__Pred__src_0_wrt_0_spec_bb2_1_bb1_1_2 + // TRUNNER-EMPTY: + + //=========== Test specialized branch tracing enum payload args in pullback ===========// + specify_test "autodiff_specialize_payload_arg_in_pb_bb" + // TRUNNER-LABEL: Specialized BTE payload arguments of basic blocks in pullback $s4test14cond_tuple_varyS2fFTJpSpSr: + // TRUNNER-NEXT: %[[#]] = argument of bb1 : $(predecessor: _AD__$s4test14cond_tuple_varyS2fF_bb2__Pred__src_0_wrt_0_spec_bb0_0_1, (Float, Float)) + // TRUNNER-NEXT: %[[#]] = argument of bb2 : $(predecessor: _AD__$s4test14cond_tuple_varyS2fF_bb1__Pred__src_0_wrt_0_spec_bb0_0_1, (), ()) + // TRUNNER-EMPTY: + //=========== Test callsite and closure gathering logic ===========// specify_test "autodiff_closure_specialize_get_pullback_closure_info" // TRUNNER-LABEL: Specializing closures in function: $s4test14cond_tuple_varyS2fFTJrSpSr