Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions include/swift/AST/ASTBridging.h
Original file line number Diff line number Diff line change
Expand Up @@ -2325,6 +2325,9 @@ BridgedBraceStmt BridgedBraceStmt_createImplicit(BridgedASTContext cContext,
BridgedASTNode element,
swift::SourceLoc rBLoc);

SWIFT_NAME("BridgedBraceStmt.hasAsyncNode(self:)")
bool BridgedBraceStmt_hasAsyncNode(BridgedBraceStmt braceStmt);

SWIFT_NAME("BridgedBreakStmt.createParsed(_:loc:targetName:targetLoc:)")
BridgedBreakStmt BridgedBreakStmt_createParsed(BridgedDeclContext cDeclContext,
swift::SourceLoc loc,
Expand Down Expand Up @@ -2356,6 +2359,10 @@ BridgedDeferStmt BridgedDeferStmt_createParsed(BridgedDeclContext cDeclContext,
SWIFT_NAME("getter:BridgedDeferStmt.tempDecl(self:)")
BridgedFuncDecl BridgedDeferStmt_getTempDecl(BridgedDeferStmt bridged);

SWIFT_NAME("BridgedDeferStmt.makeAsync(self:_:)")
void BridgedDeferStmt_makeAsync(BridgedDeferStmt bridged,
BridgedASTContext ctx);

SWIFT_NAME("BridgedDiscardStmt.createParsed(_:discardLoc:subExpr:)")
BridgedDiscardStmt BridgedDiscardStmt_createParsed(BridgedASTContext cContext,
swift::SourceLoc discardLoc,
Expand Down
2 changes: 2 additions & 0 deletions include/swift/AST/Decl.h
Original file line number Diff line number Diff line change
Expand Up @@ -7904,6 +7904,8 @@ class AbstractFunctionDecl : public GenericContext, public ValueDecl {
/// type of the function will be `async` as well.
bool hasAsync() const { return Bits.AbstractFunctionDecl.Async; }

void setHasAsync(bool async) { Bits.AbstractFunctionDecl.Async = async; }

/// Determine whether the given function is concurrent.
///
/// A function is concurrent if it has the @Sendable attribute.
Expand Down
6 changes: 6 additions & 0 deletions include/swift/AST/DiagnosticsSema.def
Original file line number Diff line number Diff line change
Expand Up @@ -5514,6 +5514,12 @@ ERROR(tryless_throwing_call_in_nonexhaustive_catch,none,
"the enclosing catch is not exhaustive", (StringRef))
ERROR(throw_in_nonexhaustive_catch,none,
"error is not handled because the enclosing catch is not exhaustive", ())
ERROR(throw_in_defer_body,none,
"errors cannot be thrown out of a defer body", ())
ERROR(throwing_op_in_defer_body,none,
"%0 can throw, but errors cannot be thrown out of a defer body", (StringRef))
ERROR(async_defer_in_non_async_context,none,
"'async' defer must appear within an 'async' context", ())

#define EFFECTS_CONTEXT_KIND \
"%select{<<ERROR>>|" \
Expand Down
6 changes: 5 additions & 1 deletion include/swift/AST/Stmt.h
Original file line number Diff line number Diff line change
Expand Up @@ -424,7 +424,11 @@ class DeferStmt : public Stmt {

/// Dig the original user's body of the defer out for AST fidelity.
BraceStmt *getBodyAsWritten() const;


/// Turn this into an async defer by modifying the temp decl and call expr
/// appropriately.
void makeAsync(ASTContext &ctx);

static bool classof(const Stmt *S) { return S->getKind() == StmtKind::Defer; }
};

Expand Down
9 changes: 9 additions & 0 deletions lib/AST/Bridging/StmtBridging.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,10 @@ BridgedBraceStmt BridgedBraceStmt_createImplicit(BridgedASTContext cContext,
/*Implicit=*/true);
}

bool BridgedBraceStmt_hasAsyncNode(BridgedBraceStmt braceStmt) {
return (bool)braceStmt.unbridged()->findAsyncNode();
}

BridgedBreakStmt BridgedBreakStmt_createParsed(BridgedDeclContext cDeclContext,
SourceLoc loc,
Identifier targetName,
Expand Down Expand Up @@ -156,6 +160,11 @@ BridgedFuncDecl BridgedDeferStmt_getTempDecl(BridgedDeferStmt bridged) {
return bridged.unbridged()->getTempDecl();
}

void BridgedDeferStmt_makeAsync(BridgedDeferStmt bridged,
BridgedASTContext ctx) {
return bridged.unbridged()->makeAsync(ctx.unbridged());
}

BridgedDiscardStmt BridgedDiscardStmt_createParsed(BridgedASTContext cContext,
SourceLoc discardLoc,
BridgedExpr cSubExpr) {
Expand Down
18 changes: 17 additions & 1 deletion lib/AST/Stmt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -270,14 +270,21 @@ ASTNode BraceStmt::findAsyncNode() {
}

PreWalkAction walkToDeclPre(Decl *decl) override {
// Do not walk into function or type declarations.
// Do not walk into function or type declarations (except for defer
// bodies).
if (auto *patternBinding = dyn_cast<PatternBindingDecl>(decl)) {
if (patternBinding->isAsyncLet())
AsyncNode = patternBinding;

return Action::Continue();
}

if (auto *fnDecl = dyn_cast<FuncDecl>(decl)) {
if (fnDecl->isDeferBody()) {
return Action::Continue();
}
}

return Action::SkipNode();
}

Expand Down Expand Up @@ -394,6 +401,15 @@ BraceStmt *DeferStmt::getBodyAsWritten() const {
return tempDecl->getBody();
}

void DeferStmt::makeAsync(ASTContext &ctx) {
tempDecl->setHasAsync(true);
setCallExpr(AwaitExpr::createImplicit(ctx, SourceLoc(), getCallExpr()));
auto *attr = new (ctx) NonisolatedAttr(SourceLoc(), SourceRange(),
NonIsolatedModifier::NonSending,
/*implicit*/ true);
tempDecl->getAttrs().add(attr);
}

bool LabeledStmt::isPossibleContinueTarget() const {
switch (getKind()) {
#define LABELED_STMT(ID, PARENT)
Expand Down
6 changes: 5 additions & 1 deletion lib/ASTGen/Sources/ASTGen/Stmts.swift
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,11 @@ extension ASTGenVisitor {
deferLoc: deferLoc
)
self.withDeclContext(stmt.tempDecl.asDeclContext) {
stmt.tempDecl.setParsedBody(self.generate(codeBlock: node.body))
let body = self.generate(codeBlock: node.body)
stmt.tempDecl.setParsedBody(body)
if body.hasAsyncNode() {
stmt.makeAsync(ctx)
}
}
return stmt
}
Expand Down
4 changes: 4 additions & 0 deletions lib/Parse/ParseStmt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1097,6 +1097,10 @@ ParserResult<Stmt> Parser::parseStmtDefer() {
return nullptr;
Status |= Body;

if (bool(Body.get()->findAsyncNode())) {
DS->makeAsync(Context);
}

// Clone the current hasher and extract a Fingerprint.
StableHasher currentHash{*CurrentTokenHash};
Fingerprint fp(std::move(currentHash));
Expand Down
57 changes: 42 additions & 15 deletions lib/Sema/TypeCheckEffects.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2854,9 +2854,6 @@ class Context {

/// The guard expression controlling a catch.
CatchGuard,

/// A defer body
DeferBody,
};

private:
Expand Down Expand Up @@ -2973,6 +2970,21 @@ class Context {
return isa<AutoClosureExpr>(closure);
}

bool isDeferBody() const {
if (!Function)
return false;

if (ErrorHandlingIgnoresFunction)
return false;

auto fnDecl =
dyn_cast_or_null<FuncDecl>(Function->getAbstractFunctionDecl());
if (!fnDecl)
return false;

return fnDecl->isDeferBody();
}

static Context forTopLevelCode(TopLevelCodeDecl *D) {
// Top-level code implicitly handles errors.
return Context(/*handlesErrors=*/true,
Expand All @@ -2999,10 +3011,6 @@ class Context {
return Context(D->hasThrows(), D->isAsyncContext(), AnyFunctionRef(D), D);
}

static Context forDeferBody(DeclContext *dc) {
return Context(Kind::DeferBody, dc);
}

static Context forInitializer(Initializer *init) {
if (isa<DefaultArgumentInitializer>(init)) {
return Context(Kind::DefaultArgument, init);
Expand Down Expand Up @@ -3283,6 +3291,12 @@ class Context {
return;
}

if (isDeferBody()) {
Diags.diagnose(E.getStartLoc(), diag::throwing_op_in_defer_body,
getEffectSourceName(reason));
return;
}

if (hasPolymorphicEffect(EffectKind::Throws)) {
diagnoseThrowInLegalContext(Diags, E, isTryCovered, reason,
diag::throwing_call_in_rethrows_function,
Expand All @@ -3303,7 +3317,6 @@ class Context {
case Kind::PropertyWrapper:
case Kind::CatchPattern:
case Kind::CatchGuard:
case Kind::DeferBody:
Diags.diagnose(E.getStartLoc(), diag::throwing_op_in_illegal_context,
static_cast<unsigned>(getKind()), getEffectSourceName(reason));
return;
Expand All @@ -3324,6 +3337,11 @@ class Context {
return;
}

if (isDeferBody()) {
Diags.diagnose(S->getStartLoc(), diag::throw_in_defer_body);
return;
}

if (hasPolymorphicEffect(EffectKind::Throws)) {
Diags.diagnose(S->getStartLoc(), diag::throw_in_rethrows_function);
return;
Expand All @@ -3340,7 +3358,6 @@ class Context {
case Kind::PropertyWrapper:
case Kind::CatchPattern:
case Kind::CatchGuard:
case Kind::DeferBody:
Diags.diagnose(S->getStartLoc(), diag::throw_in_illegal_context,
static_cast<unsigned>(getKind()));
return;
Expand All @@ -3367,7 +3384,6 @@ class Context {
case Kind::PropertyWrapper:
case Kind::CatchPattern:
case Kind::CatchGuard:
case Kind::DeferBody:
assert(!DiagnoseErrorOnTry);
// Diagnosed at the call sites.
return;
Expand Down Expand Up @@ -3448,6 +3464,21 @@ class Context {
void diagnoseUnhandledAsyncSite(DiagnosticEngine &Diags, ASTNode node,
std::optional<PotentialEffectReason> maybeReason,
bool forAwait = false) {

// If this is an apply of a defer body, emit a special diagnostic. We must
// check this before we check `isImplicit` below since these are always
// implicit!
if (auto *applyExpr = dyn_cast_or_null<ApplyExpr>(node.dyn_cast<Expr*>())) {
auto *calledDecl = applyExpr->getCalledValue(/*skipConversions=*/true);
if (auto *fnDecl = dyn_cast_or_null<FuncDecl>(calledDecl)) {
if (fnDecl->isDeferBody()) {
Diags.diagnose(fnDecl->getStartLoc(),
diag::async_defer_in_non_async_context);
return;
}
}
}

if (node.isImplicit()) {
// The reason we return early on implicit nodes is that sometimes we
// inject implicit closures, e.g. in 'async let' and we'd end up
Expand Down Expand Up @@ -3477,7 +3508,6 @@ class Context {
case Kind::PropertyWrapper:
case Kind::CatchPattern:
case Kind::CatchGuard:
case Kind::DeferBody:
diagnoseAsyncInIllegalContext(Diags, node);
return;
}
Expand Down Expand Up @@ -4639,7 +4669,6 @@ class CheckEffectsCoverage : public EffectsHandlingWalker<CheckEffectsCoverage>
ContextScope scope(*this, std::nullopt);
scope.enterUnsafe(S->getDeferLoc());

// Walk the call expression. We don't care about the rest.
S->getCallExpr()->walk(*this);

return ShouldNotRecurse;
Expand Down Expand Up @@ -5004,9 +5033,7 @@ void TypeChecker::checkFunctionEffects(AbstractFunctionDecl *fn) {
PrettyStackTraceDecl debugStack("checking effects handling for", fn);
#endif

auto isDeferBody = isa<FuncDecl>(fn) && cast<FuncDecl>(fn)->isDeferBody();
auto context =
isDeferBody ? Context::forDeferBody(fn) : Context::forFunction(fn);
auto context = Context::forFunction(fn);
auto &ctx = fn->getASTContext();
CheckEffectsCoverage checker(ctx, context);

Expand Down
11 changes: 4 additions & 7 deletions test/expr/unary/async_await.swift
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,6 @@ func getInt() async -> Int { return 5 }
func test2(
defaulted: Int = await getInt() // expected-error{{'async' call cannot occur in a default argument}}
) async {
defer {
_ = await getInt() // expected-error{{'async' call cannot occur in a defer body}}
}
print("foo")
}

Expand Down Expand Up @@ -182,10 +179,10 @@ func testAsyncLet() async throws {
}

defer {
async let deferX: Int = await getInt() // expected-error {{'async let' cannot be used on declarations in a defer body}}
_ = await deferX // expected-error {{async let 'deferX' cannot be referenced in a defer body}}
async let _: Int = await getInt() // expected-error {{'async let' cannot be used on declarations in a defer body}}
async let _ = await getInt() // expected-error {{'async let' cannot be used on declarations in a defer body}}
async let deferX: Int = await getInt()
_ = await deferX
async let _: Int = await getInt()
async let _ = await getInt()
}

async let x1 = getIntUnsafely() // okay, try is implicit here
Expand Down
24 changes: 24 additions & 0 deletions test/stmt/defer.swift
Original file line number Diff line number Diff line change
Expand Up @@ -161,3 +161,27 @@ func badForwardReference() {

let z2 = 0
}


public func basicAsyncDefer() async {
defer { await asyncFunc() }
voidFunc()
}

func asyncFunc() async {}
func voidFunc() {}

func testClosure() async {
let f = {
defer { await asyncFunc() }
voidFunc()
}

await f()
}

func asyncDeferInSyncFunc() {
defer { await asyncFunc() } // expected-error {{'async' defer must appear within an 'async' context}}
voidFunc()
}

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what behavior do we get if we have no source-level 'await' in the defer?

func f() {
  defer { async let _: () = voidFunc() }
  voidFunc()
}