diff --git a/lib/IRGen/GenFunc.cpp b/lib/IRGen/GenFunc.cpp index bd8c8395d436a..a7ac8cef0fb5a 100644 --- a/lib/IRGen/GenFunc.cpp +++ b/lib/IRGen/GenFunc.cpp @@ -1039,9 +1039,11 @@ class AsyncPartialApplicationForwarderEmission : public PartialApplicationForwarderEmission { using super = PartialApplicationForwarderEmission; AsyncContextLayout layout; - llvm::Value *contextBuffer; + llvm::Value *calleeFunction; + llvm::Value *currentResumeFn; Size contextSize; Address context; + Address calleeContextBuffer; unsigned currentArgumentIndex; struct DynamicFunction { using Kind = DynamicFunctionKind; @@ -1060,23 +1062,11 @@ class AsyncPartialApplicationForwarderEmission }; Optional self = llvm::None; - llvm::Value *loadValue(ElementLayout layout) { - Address addr = layout.project(subIGF, context, /*offsets*/ llvm::None); - auto &ti = cast(layout.getType()); - Explosion explosion; - ti.loadAsTake(subIGF, addr, explosion); - return explosion.claimNext(); - } void saveValue(ElementLayout layout, Explosion &explosion) { Address addr = layout.project(subIGF, context, /*offsets*/ llvm::None); auto &ti = cast(layout.getType()); ti.initialize(subIGF, explosion, addr, /*isOutlined*/ false); } - void loadValue(ElementLayout layout, Explosion &explosion) { - Address addr = layout.project(subIGF, context, /*offsets*/ llvm::None); - auto &ti = cast(layout.getType()); - ti.loadAsTake(subIGF, addr, explosion); - } public: AsyncPartialApplicationForwarderEmission( @@ -1099,9 +1089,59 @@ class AsyncPartialApplicationForwarderEmission void begin() override { super::begin(); } void mapAsyncParameters() override { - contextBuffer = origParams.claimNext(); - context = layout.emitCastTo(subIGF, contextBuffer); - args.add(contextBuffer); + // Ignore the original context. + (void)origParams.claimNext(); + + llvm::Value *dynamicContextSize32; + auto initialContextSize = Size(0); + std::tie(calleeFunction, dynamicContextSize32) = getAsyncFunctionAndSize( + subIGF, origType->getRepresentation(), *staticFnPtr, + nullptr, std::make_pair(true, true), initialContextSize); + auto *dynamicContextSize = + subIGF.Builder.CreateZExt(dynamicContextSize32, subIGF.IGM.SizeTy); + calleeContextBuffer = + emitAllocAsyncContext(subIGF, dynamicContextSize); + context = layout.emitCastTo(subIGF, calleeContextBuffer.getAddress()); + auto calleeContext = + layout.emitCastTo(subIGF, calleeContextBuffer.getAddress()); + args.add(subIGF.Builder.CreateBitOrPointerCast( + calleeContextBuffer.getAddress(), IGM.SwiftContextPtrTy)); + + // Set caller info into the context. + { // caller context + Explosion explosion; + auto fieldLayout = layout.getParentLayout(); + auto *context = subIGF.getAsyncContext(); + if (auto schema = + subIGF.IGM.getOptions().PointerAuth.AsyncContextParent) { + Address fieldAddr = + fieldLayout.project(subIGF, calleeContext, /*offsets*/ llvm::None); + auto authInfo = PointerAuthInfo::emit( + subIGF, schema, fieldAddr.getAddress(), PointerAuthEntity()); + context = emitPointerAuthSign(subIGF, context, authInfo); + } + explosion.add(context); + saveValue(fieldLayout, explosion); + } + { // Return to caller function. + auto fieldLayout = layout.getResumeParentLayout(); + currentResumeFn = subIGF.Builder.CreateIntrinsicCall( + llvm::Intrinsic::coro_async_resume, {}); + auto fnVal = currentResumeFn; + // Sign the pointer. + if (auto schema = subIGF.IGM.getOptions().PointerAuth.AsyncContextResume) { + Address fieldAddr = + fieldLayout.project(subIGF, calleeContext, /*offsets*/ llvm::None); + auto authInfo = PointerAuthInfo::emit( + subIGF, schema, fieldAddr.getAddress(), PointerAuthEntity()); + fnVal = emitPointerAuthSign(subIGF, fnVal, authInfo); + } + fnVal = subIGF.Builder.CreateBitCast( + fnVal, subIGF.IGM.TaskContinuationFunctionPtrTy); + Explosion explosion; + explosion.add(fnVal); + saveValue(fieldLayout, explosion); + } } void gatherArgumentsFromApply() override { super::gatherArgumentsFromApply(true); @@ -1127,13 +1167,87 @@ class AsyncPartialApplicationForwarderEmission // Nothing to do here. The error result pointer is already in the // appropriate position. } + FunctionPointer getFunctionPointerForDispatchCall(const FunctionPointer &fn) { + auto &IGM = subIGF.IGM; + // Strip off the return type. The original function pointer signature + // captured both the entry point type and the resume function type. + auto *fnTy = llvm::FunctionType::get( + IGM.VoidTy, fn.getSignature().getType()->params(), false /*vaargs*/); + auto signature = + Signature(fnTy, fn.getSignature().getAttributes(), IGM.SwiftAsyncCC); + auto fnPtr = + FunctionPointer(FunctionPointer::Kind::Function, fn.getRawPointer(), + fn.getAuthInfo(), signature); + return fnPtr; + } llvm::CallInst *createCall(FunctionPointer &fnPtr) override { - return subIGF.Builder.CreateCall(fnPtr.getAsFunction(subIGF), - args.claimAll()); + auto newFnPtr = FunctionPointer( + FunctionPointer::Kind::Function, fnPtr.getPointer(subIGF), + fnPtr.getAuthInfo(), Signature::forAsyncAwait(subIGF.IGM, origType)); + auto &Builder = subIGF.Builder; + + auto argValues = args.claimAll(); + + // Setup the suspend point. + SmallVector arguments; + auto signature = newFnPtr.getSignature(); + auto asyncContextIndex = signature.getAsyncContextIndex(); + auto paramAttributeFlags = + asyncContextIndex | + (signature.getAsyncResumeFunctionSwiftSelfIndex() << 8); + // Index of swiftasync context | ((index of swiftself) << 8). + arguments.push_back( + IGM.getInt32(paramAttributeFlags)); + arguments.push_back(currentResumeFn); + auto resumeProjFn = subIGF.getOrCreateResumePrjFn(); + arguments.push_back( + Builder.CreateBitOrPointerCast(resumeProjFn, IGM.Int8PtrTy)); + auto dispatchFn = subIGF.createAsyncDispatchFn( + getFunctionPointerForDispatchCall(newFnPtr), argValues); + arguments.push_back( + Builder.CreateBitOrPointerCast(dispatchFn, IGM.Int8PtrTy)); + arguments.push_back( + Builder.CreateBitOrPointerCast(newFnPtr.getRawPointer(), IGM.Int8PtrTy)); + if (auto authInfo = newFnPtr.getAuthInfo()) { + arguments.push_back(newFnPtr.getAuthInfo().getDiscriminator()); + } + for (auto arg : argValues) + arguments.push_back(arg); + auto resultTy = + cast(signature.getType()->getReturnType()); + return subIGF.emitSuspendAsyncCall(asyncContextIndex, resultTy, arguments); } void createReturn(llvm::CallInst *call) override { - call->setTailCallKind(IGM.AsyncTailCallKind); - subIGF.Builder.CreateRetVoid(); + emitDeallocAsyncContext(subIGF, calleeContextBuffer); + auto numAsyncContextParams = + Signature::forAsyncReturn(IGM, outType).getAsyncContextIndex() + 1; + llvm::Value *result = call; + auto *suspendResultTy = cast(result->getType()); + Explosion resultExplosion; + Explosion errorExplosion; + SILFunctionConventions conv(outType, subIGF.getSILModule()); + auto hasError = outType->hasErrorResult(); + + Optional> nativeResults = llvm::None; + SmallVector nativeResultsStorage; + + if (suspendResultTy->getNumElements() == numAsyncContextParams) { + // no result to forward. + assert(!hasError); + } else { + auto &Builder = subIGF.Builder; + auto resultTys = + makeArrayRef(suspendResultTy->element_begin() + numAsyncContextParams, + suspendResultTy->element_end()); + + for (unsigned i = 0, e = resultTys.size(); i != e; ++i) { + llvm::Value *elt = + Builder.CreateExtractValue(result, numAsyncContextParams + i); + nativeResultsStorage.push_back(elt); + } + nativeResults = nativeResultsStorage; + } + emitAsyncReturn(subIGF, layout, origType, nativeResults); } void end() override { assert(context.isValid()); @@ -1180,6 +1294,7 @@ static llvm::Value *emitPartialApplicationForwarder(IRGenModule &IGM, llvm::AttributeList outAttrs = outSig.getAttributes(); llvm::FunctionType *fwdTy = outSig.getType(); SILFunctionConventions outConv(outType, IGM.getSILModule()); + Optional asyncLayout; StringRef FnName; if (staticFnPtr) @@ -1203,19 +1318,29 @@ static llvm::Value *emitPartialApplicationForwarder(IRGenModule &IGM, IRGenFunction subIGF(IGM, fwd); if (origType->isAsync()) { - subIGF.setupAsync( - Signature::forAsyncEntry(IGM, outType).getAsyncContextIndex()); + auto asyncContextIdx = + Signature::forAsyncEntry(IGM, outType).getAsyncContextIndex(); + asyncLayout.emplace(irgen::getAsyncContextLayout( + IGM, origType, substType, subs, /*suppress generics*/ false, + FunctionPointer::Kind( + FunctionPointer::BasicKind::AsyncFunctionPointer))); + + subIGF.setupAsync(asyncContextIdx); - auto *calleeAFP = staticFnPtr->getDirectPointer(); + //auto *calleeAFP = staticFnPtr->getDirectPointer(); LinkEntity entity = LinkEntity::forPartialApplyForwarder(fwd); - auto size = Size(0); assert(!asyncFunctionPtr && "already had an async function pointer to the forwarder?!"); - asyncFunctionPtr = emitAsyncFunctionPointer(IGM, fwd, entity, size); + emitAsyncFunctionEntry(subIGF, *asyncLayout, entity, asyncContextIdx); + asyncFunctionPtr = + emitAsyncFunctionPointer(IGM, fwd, entity, asyncLayout->getSize()); + // TODO: if calleeAFP is definition: +#if 0 subIGF.Builder.CreateIntrinsicCall( llvm::Intrinsic::coro_async_size_replace, {subIGF.Builder.CreateBitCast(asyncFunctionPtr, IGM.Int8PtrTy), subIGF.Builder.CreateBitCast(calleeAFP, IGM.Int8PtrTy)}); +#endif } if (IGM.DebugInfo) IGM.DebugInfo->emitArtificialFunction(subIGF, fwd); @@ -1679,7 +1804,7 @@ static llvm::Value *emitPartialApplicationForwarder(IRGenModule &IGM, llvm::CallInst *call = emission->createCall(fnPtr); - if (addressesToDeallocate.empty() && !needsAllocas && + if (!origType->isAsync() && addressesToDeallocate.empty() && !needsAllocas && (!consumesContext || !dependsOnContextLifetime)) call->setTailCall(); diff --git a/test/IRGen/async/partial_apply.sil b/test/IRGen/async/partial_apply.sil index 1398dbcc13994..374923e54e28e 100644 --- a/test/IRGen/async/partial_apply.sil +++ b/test/IRGen/async/partial_apply.sil @@ -506,3 +506,12 @@ bb0(%thick : $@callee_guaranteed @async @convention(thick) (Int64, Int32) -> Int // CHECK-LABEL: define internal swift{{(tail)?}}cc void @"$s45indirect_guaranteed_captured_class_pair_paramTA.{{[0-9]+}}"( // CHECK-LABEL: define internal swift{{(tail)?}}cc void @"$s12create_pa_f2Tw_"( // CHECK-LABEL: define internal swift{{(tail)?}}cc void @"$s12create_pa_f2Tw0_"( + +sil @external_closure : $@convention(thin) @async (Int, Int) -> (Int, @error Error) + +sil @dont_crash : $@convention(thin) @async (Int) -> @owned @async @callee_guaranteed (Int) -> (Int, @error Error) { +bb0(%0 : $Int): + %2 = function_ref @external_closure : $@convention(thin) @async (Int, Int) -> (Int, @error Error) + %3 = partial_apply [callee_guaranteed] %2(%0) : $@convention(thin) @async (Int, Int) -> (Int, @error Error) + return %3 : $@async @callee_guaranteed (Int) -> (Int, @error Error) +}