Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
177 changes: 151 additions & 26 deletions lib/IRGen/GenFunc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1039,9 +1039,11 @@ class AsyncPartialApplicationForwarderEmission
: public PartialApplicationForwarderEmission {
using super = PartialApplicationForwarderEmission;
AsyncContextLayout layout;
llvm::Value *contextBuffer;
llvm::Value *calleeFunction;
llvm::Value *currentResumeFn;
Size contextSize;
Address context;
Address calleeContextBuffer;
unsigned currentArgumentIndex;
struct DynamicFunction {
using Kind = DynamicFunctionKind;
Expand All @@ -1060,23 +1062,11 @@ class AsyncPartialApplicationForwarderEmission
};
Optional<Self> self = llvm::None;

llvm::Value *loadValue(ElementLayout layout) {
Address addr = layout.project(subIGF, context, /*offsets*/ llvm::None);
auto &ti = cast<LoadableTypeInfo>(layout.getType());
Explosion explosion;
ti.loadAsTake(subIGF, addr, explosion);
return explosion.claimNext();
}
void saveValue(ElementLayout layout, Explosion &explosion) {
Address addr = layout.project(subIGF, context, /*offsets*/ llvm::None);
auto &ti = cast<LoadableTypeInfo>(layout.getType());
ti.initialize(subIGF, explosion, addr, /*isOutlined*/ false);
}
void loadValue(ElementLayout layout, Explosion &explosion) {
Address addr = layout.project(subIGF, context, /*offsets*/ llvm::None);
auto &ti = cast<LoadableTypeInfo>(layout.getType());
ti.loadAsTake(subIGF, addr, explosion);
}

public:
AsyncPartialApplicationForwarderEmission(
Expand All @@ -1099,9 +1089,59 @@ class AsyncPartialApplicationForwarderEmission
void begin() override { super::begin(); }

void mapAsyncParameters() override {
contextBuffer = origParams.claimNext();
context = layout.emitCastTo(subIGF, contextBuffer);
args.add(contextBuffer);
// Ignore the original context.
(void)origParams.claimNext();

llvm::Value *dynamicContextSize32;
auto initialContextSize = Size(0);
std::tie(calleeFunction, dynamicContextSize32) = getAsyncFunctionAndSize(
subIGF, origType->getRepresentation(), *staticFnPtr,
nullptr, std::make_pair(true, true), initialContextSize);
auto *dynamicContextSize =
subIGF.Builder.CreateZExt(dynamicContextSize32, subIGF.IGM.SizeTy);
calleeContextBuffer =
emitAllocAsyncContext(subIGF, dynamicContextSize);
context = layout.emitCastTo(subIGF, calleeContextBuffer.getAddress());
auto calleeContext =
layout.emitCastTo(subIGF, calleeContextBuffer.getAddress());
args.add(subIGF.Builder.CreateBitOrPointerCast(
calleeContextBuffer.getAddress(), IGM.SwiftContextPtrTy));

// Set caller info into the context.
{ // caller context
Explosion explosion;
auto fieldLayout = layout.getParentLayout();
auto *context = subIGF.getAsyncContext();
if (auto schema =
subIGF.IGM.getOptions().PointerAuth.AsyncContextParent) {
Address fieldAddr =
fieldLayout.project(subIGF, calleeContext, /*offsets*/ llvm::None);
auto authInfo = PointerAuthInfo::emit(
subIGF, schema, fieldAddr.getAddress(), PointerAuthEntity());
context = emitPointerAuthSign(subIGF, context, authInfo);
}
explosion.add(context);
saveValue(fieldLayout, explosion);
}
{ // Return to caller function.
auto fieldLayout = layout.getResumeParentLayout();
currentResumeFn = subIGF.Builder.CreateIntrinsicCall(
llvm::Intrinsic::coro_async_resume, {});
auto fnVal = currentResumeFn;
// Sign the pointer.
if (auto schema = subIGF.IGM.getOptions().PointerAuth.AsyncContextResume) {
Address fieldAddr =
fieldLayout.project(subIGF, calleeContext, /*offsets*/ llvm::None);
auto authInfo = PointerAuthInfo::emit(
subIGF, schema, fieldAddr.getAddress(), PointerAuthEntity());
fnVal = emitPointerAuthSign(subIGF, fnVal, authInfo);
}
fnVal = subIGF.Builder.CreateBitCast(
fnVal, subIGF.IGM.TaskContinuationFunctionPtrTy);
Explosion explosion;
explosion.add(fnVal);
saveValue(fieldLayout, explosion);
}
}
void gatherArgumentsFromApply() override {
super::gatherArgumentsFromApply(true);
Expand All @@ -1127,13 +1167,87 @@ class AsyncPartialApplicationForwarderEmission
// Nothing to do here. The error result pointer is already in the
// appropriate position.
}
FunctionPointer getFunctionPointerForDispatchCall(const FunctionPointer &fn) {
auto &IGM = subIGF.IGM;
// Strip off the return type. The original function pointer signature
// captured both the entry point type and the resume function type.
auto *fnTy = llvm::FunctionType::get(
IGM.VoidTy, fn.getSignature().getType()->params(), false /*vaargs*/);
auto signature =
Signature(fnTy, fn.getSignature().getAttributes(), IGM.SwiftAsyncCC);
auto fnPtr =
FunctionPointer(FunctionPointer::Kind::Function, fn.getRawPointer(),
fn.getAuthInfo(), signature);
return fnPtr;
}
llvm::CallInst *createCall(FunctionPointer &fnPtr) override {
return subIGF.Builder.CreateCall(fnPtr.getAsFunction(subIGF),
args.claimAll());
auto newFnPtr = FunctionPointer(
FunctionPointer::Kind::Function, fnPtr.getPointer(subIGF),
fnPtr.getAuthInfo(), Signature::forAsyncAwait(subIGF.IGM, origType));
auto &Builder = subIGF.Builder;

auto argValues = args.claimAll();

// Setup the suspend point.
SmallVector<llvm::Value *, 8> arguments;
auto signature = newFnPtr.getSignature();
auto asyncContextIndex = signature.getAsyncContextIndex();
auto paramAttributeFlags =
asyncContextIndex |
(signature.getAsyncResumeFunctionSwiftSelfIndex() << 8);
// Index of swiftasync context | ((index of swiftself) << 8).
arguments.push_back(
IGM.getInt32(paramAttributeFlags));
arguments.push_back(currentResumeFn);
auto resumeProjFn = subIGF.getOrCreateResumePrjFn();
arguments.push_back(
Builder.CreateBitOrPointerCast(resumeProjFn, IGM.Int8PtrTy));
auto dispatchFn = subIGF.createAsyncDispatchFn(
getFunctionPointerForDispatchCall(newFnPtr), argValues);
arguments.push_back(
Builder.CreateBitOrPointerCast(dispatchFn, IGM.Int8PtrTy));
arguments.push_back(
Builder.CreateBitOrPointerCast(newFnPtr.getRawPointer(), IGM.Int8PtrTy));
if (auto authInfo = newFnPtr.getAuthInfo()) {
arguments.push_back(newFnPtr.getAuthInfo().getDiscriminator());
}
for (auto arg : argValues)
arguments.push_back(arg);
auto resultTy =
cast<llvm::StructType>(signature.getType()->getReturnType());
return subIGF.emitSuspendAsyncCall(asyncContextIndex, resultTy, arguments);
}
void createReturn(llvm::CallInst *call) override {
call->setTailCallKind(IGM.AsyncTailCallKind);
subIGF.Builder.CreateRetVoid();
emitDeallocAsyncContext(subIGF, calleeContextBuffer);
auto numAsyncContextParams =
Signature::forAsyncReturn(IGM, outType).getAsyncContextIndex() + 1;
llvm::Value *result = call;
auto *suspendResultTy = cast<llvm::StructType>(result->getType());
Explosion resultExplosion;
Explosion errorExplosion;
SILFunctionConventions conv(outType, subIGF.getSILModule());
auto hasError = outType->hasErrorResult();

Optional<ArrayRef<llvm::Value *>> nativeResults = llvm::None;
SmallVector<llvm::Value *, 16> nativeResultsStorage;

if (suspendResultTy->getNumElements() == numAsyncContextParams) {
// no result to forward.
assert(!hasError);
} else {
auto &Builder = subIGF.Builder;
auto resultTys =
makeArrayRef(suspendResultTy->element_begin() + numAsyncContextParams,
suspendResultTy->element_end());

for (unsigned i = 0, e = resultTys.size(); i != e; ++i) {
llvm::Value *elt =
Builder.CreateExtractValue(result, numAsyncContextParams + i);
nativeResultsStorage.push_back(elt);
}
nativeResults = nativeResultsStorage;
}
emitAsyncReturn(subIGF, layout, origType, nativeResults);
}
void end() override {
assert(context.isValid());
Expand Down Expand Up @@ -1180,6 +1294,7 @@ static llvm::Value *emitPartialApplicationForwarder(IRGenModule &IGM,
llvm::AttributeList outAttrs = outSig.getAttributes();
llvm::FunctionType *fwdTy = outSig.getType();
SILFunctionConventions outConv(outType, IGM.getSILModule());
Optional<AsyncContextLayout> asyncLayout;

StringRef FnName;
if (staticFnPtr)
Expand All @@ -1203,19 +1318,29 @@ static llvm::Value *emitPartialApplicationForwarder(IRGenModule &IGM,

IRGenFunction subIGF(IGM, fwd);
if (origType->isAsync()) {
subIGF.setupAsync(
Signature::forAsyncEntry(IGM, outType).getAsyncContextIndex());
auto asyncContextIdx =
Signature::forAsyncEntry(IGM, outType).getAsyncContextIndex();
asyncLayout.emplace(irgen::getAsyncContextLayout(
IGM, origType, substType, subs, /*suppress generics*/ false,
FunctionPointer::Kind(
FunctionPointer::BasicKind::AsyncFunctionPointer)));

subIGF.setupAsync(asyncContextIdx);

auto *calleeAFP = staticFnPtr->getDirectPointer();
//auto *calleeAFP = staticFnPtr->getDirectPointer();
LinkEntity entity = LinkEntity::forPartialApplyForwarder(fwd);
auto size = Size(0);
assert(!asyncFunctionPtr &&
"already had an async function pointer to the forwarder?!");
asyncFunctionPtr = emitAsyncFunctionPointer(IGM, fwd, entity, size);
emitAsyncFunctionEntry(subIGF, *asyncLayout, entity, asyncContextIdx);
asyncFunctionPtr =
emitAsyncFunctionPointer(IGM, fwd, entity, asyncLayout->getSize());
// TODO: if calleeAFP is definition:
#if 0
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this still be here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. We will want to reinstate the old code conditionally later. This serves as a reminder for me how-to do it.

subIGF.Builder.CreateIntrinsicCall(
llvm::Intrinsic::coro_async_size_replace,
{subIGF.Builder.CreateBitCast(asyncFunctionPtr, IGM.Int8PtrTy),
subIGF.Builder.CreateBitCast(calleeAFP, IGM.Int8PtrTy)});
#endif
}
if (IGM.DebugInfo)
IGM.DebugInfo->emitArtificialFunction(subIGF, fwd);
Expand Down Expand Up @@ -1679,7 +1804,7 @@ static llvm::Value *emitPartialApplicationForwarder(IRGenModule &IGM,

llvm::CallInst *call = emission->createCall(fnPtr);

if (addressesToDeallocate.empty() && !needsAllocas &&
if (!origType->isAsync() && addressesToDeallocate.empty() && !needsAllocas &&
(!consumesContext || !dependsOnContextLifetime))
call->setTailCall();

Expand Down
9 changes: 9 additions & 0 deletions test/IRGen/async/partial_apply.sil
Original file line number Diff line number Diff line change
Expand Up @@ -506,3 +506,12 @@ bb0(%thick : $@callee_guaranteed @async @convention(thick) (Int64, Int32) -> Int
// CHECK-LABEL: define internal swift{{(tail)?}}cc void @"$s45indirect_guaranteed_captured_class_pair_paramTA.{{[0-9]+}}"(
// CHECK-LABEL: define internal swift{{(tail)?}}cc void @"$s12create_pa_f2Tw_"(
// CHECK-LABEL: define internal swift{{(tail)?}}cc void @"$s12create_pa_f2Tw0_"(

sil @external_closure : $@convention(thin) @async (Int, Int) -> (Int, @error Error)

sil @dont_crash : $@convention(thin) @async (Int) -> @owned @async @callee_guaranteed (Int) -> (Int, @error Error) {
bb0(%0 : $Int):
%2 = function_ref @external_closure : $@convention(thin) @async (Int, Int) -> (Int, @error Error)
%3 = partial_apply [callee_guaranteed] %2(%0) : $@convention(thin) @async (Int, Int) -> (Int, @error Error)
return %3 : $@async @callee_guaranteed (Int) -> (Int, @error Error)
}