diff --git a/include/swift/AST/ASTBridging.h b/include/swift/AST/ASTBridging.h index 99f9d05e0eef3..b5db504efac37 100644 --- a/include/swift/AST/ASTBridging.h +++ b/include/swift/AST/ASTBridging.h @@ -2325,6 +2325,9 @@ BridgedBraceStmt BridgedBraceStmt_createImplicit(BridgedASTContext cContext, BridgedASTNode element, swift::SourceLoc rBLoc); +SWIFT_NAME("BridgedBraceStmt.hasAsyncNode(self:)") +bool BridgedBraceStmt_hasAsyncNode(BridgedBraceStmt braceStmt); + SWIFT_NAME("BridgedBreakStmt.createParsed(_:loc:targetName:targetLoc:)") BridgedBreakStmt BridgedBreakStmt_createParsed(BridgedDeclContext cDeclContext, swift::SourceLoc loc, @@ -2356,6 +2359,10 @@ BridgedDeferStmt BridgedDeferStmt_createParsed(BridgedDeclContext cDeclContext, SWIFT_NAME("getter:BridgedDeferStmt.tempDecl(self:)") BridgedFuncDecl BridgedDeferStmt_getTempDecl(BridgedDeferStmt bridged); +SWIFT_NAME("BridgedDeferStmt.makeAsync(self:_:)") +void BridgedDeferStmt_makeAsync(BridgedDeferStmt bridged, + BridgedASTContext ctx); + SWIFT_NAME("BridgedDiscardStmt.createParsed(_:discardLoc:subExpr:)") BridgedDiscardStmt BridgedDiscardStmt_createParsed(BridgedASTContext cContext, swift::SourceLoc discardLoc, diff --git a/include/swift/AST/Decl.h b/include/swift/AST/Decl.h index 51376bbceaec4..ba20291156d3a 100644 --- a/include/swift/AST/Decl.h +++ b/include/swift/AST/Decl.h @@ -7904,6 +7904,8 @@ class AbstractFunctionDecl : public GenericContext, public ValueDecl { /// type of the function will be `async` as well. bool hasAsync() const { return Bits.AbstractFunctionDecl.Async; } + void setHasAsync(bool async) { Bits.AbstractFunctionDecl.Async = async; } + /// Determine whether the given function is concurrent. /// /// A function is concurrent if it has the @Sendable attribute. diff --git a/include/swift/AST/DiagnosticsSema.def b/include/swift/AST/DiagnosticsSema.def index d7ffcbd2cdd0c..d508ef6e9acd9 100644 --- a/include/swift/AST/DiagnosticsSema.def +++ b/include/swift/AST/DiagnosticsSema.def @@ -5514,6 +5514,12 @@ ERROR(tryless_throwing_call_in_nonexhaustive_catch,none, "the enclosing catch is not exhaustive", (StringRef)) ERROR(throw_in_nonexhaustive_catch,none, "error is not handled because the enclosing catch is not exhaustive", ()) +ERROR(throw_in_defer_body,none, + "errors cannot be thrown out of a defer body", ()) +ERROR(throwing_op_in_defer_body,none, + "%0 can throw, but errors cannot be thrown out of a defer body", (StringRef)) +ERROR(async_defer_in_non_async_context,none, + "'async' defer must appear within an 'async' context", ()) #define EFFECTS_CONTEXT_KIND \ "%select{<>|" \ diff --git a/include/swift/AST/Stmt.h b/include/swift/AST/Stmt.h index c8f9d4bb3e76d..a09e11dfc9efa 100644 --- a/include/swift/AST/Stmt.h +++ b/include/swift/AST/Stmt.h @@ -424,7 +424,11 @@ class DeferStmt : public Stmt { /// Dig the original user's body of the defer out for AST fidelity. BraceStmt *getBodyAsWritten() const; - + + /// Turn this into an async defer by modifying the temp decl and call expr + /// appropriately. + void makeAsync(ASTContext &ctx); + static bool classof(const Stmt *S) { return S->getKind() == StmtKind::Defer; } }; diff --git a/lib/AST/Bridging/StmtBridging.cpp b/lib/AST/Bridging/StmtBridging.cpp index e8c11bb98d0cf..431e6a20eb2ff 100644 --- a/lib/AST/Bridging/StmtBridging.cpp +++ b/lib/AST/Bridging/StmtBridging.cpp @@ -96,6 +96,10 @@ BridgedBraceStmt BridgedBraceStmt_createImplicit(BridgedASTContext cContext, /*Implicit=*/true); } +bool BridgedBraceStmt_hasAsyncNode(BridgedBraceStmt braceStmt) { + return (bool)braceStmt.unbridged()->findAsyncNode(); +} + BridgedBreakStmt BridgedBreakStmt_createParsed(BridgedDeclContext cDeclContext, SourceLoc loc, Identifier targetName, @@ -156,6 +160,11 @@ BridgedFuncDecl BridgedDeferStmt_getTempDecl(BridgedDeferStmt bridged) { return bridged.unbridged()->getTempDecl(); } +void BridgedDeferStmt_makeAsync(BridgedDeferStmt bridged, + BridgedASTContext ctx) { + return bridged.unbridged()->makeAsync(ctx.unbridged()); +} + BridgedDiscardStmt BridgedDiscardStmt_createParsed(BridgedASTContext cContext, SourceLoc discardLoc, BridgedExpr cSubExpr) { diff --git a/lib/AST/Stmt.cpp b/lib/AST/Stmt.cpp index f1460e2383ed6..0535babe7582c 100644 --- a/lib/AST/Stmt.cpp +++ b/lib/AST/Stmt.cpp @@ -270,7 +270,8 @@ ASTNode BraceStmt::findAsyncNode() { } PreWalkAction walkToDeclPre(Decl *decl) override { - // Do not walk into function or type declarations. + // Do not walk into function or type declarations (except for defer + // bodies). if (auto *patternBinding = dyn_cast(decl)) { if (patternBinding->isAsyncLet()) AsyncNode = patternBinding; @@ -278,6 +279,12 @@ ASTNode BraceStmt::findAsyncNode() { return Action::Continue(); } + if (auto *fnDecl = dyn_cast(decl)) { + if (fnDecl->isDeferBody()) { + return Action::Continue(); + } + } + return Action::SkipNode(); } @@ -394,6 +401,15 @@ BraceStmt *DeferStmt::getBodyAsWritten() const { return tempDecl->getBody(); } +void DeferStmt::makeAsync(ASTContext &ctx) { + tempDecl->setHasAsync(true); + setCallExpr(AwaitExpr::createImplicit(ctx, SourceLoc(), getCallExpr())); + auto *attr = new (ctx) NonisolatedAttr(SourceLoc(), SourceRange(), + NonIsolatedModifier::NonSending, + /*implicit*/ true); + tempDecl->getAttrs().add(attr); +} + bool LabeledStmt::isPossibleContinueTarget() const { switch (getKind()) { #define LABELED_STMT(ID, PARENT) diff --git a/lib/ASTGen/Sources/ASTGen/Stmts.swift b/lib/ASTGen/Sources/ASTGen/Stmts.swift index c97d8d558e3c4..181cacce0e079 100644 --- a/lib/ASTGen/Sources/ASTGen/Stmts.swift +++ b/lib/ASTGen/Sources/ASTGen/Stmts.swift @@ -277,7 +277,11 @@ extension ASTGenVisitor { deferLoc: deferLoc ) self.withDeclContext(stmt.tempDecl.asDeclContext) { - stmt.tempDecl.setParsedBody(self.generate(codeBlock: node.body)) + let body = self.generate(codeBlock: node.body) + stmt.tempDecl.setParsedBody(body) + if body.hasAsyncNode() { + stmt.makeAsync(ctx) + } } return stmt } diff --git a/lib/Parse/ParseStmt.cpp b/lib/Parse/ParseStmt.cpp index 35208bdca0393..9990491338086 100644 --- a/lib/Parse/ParseStmt.cpp +++ b/lib/Parse/ParseStmt.cpp @@ -1097,6 +1097,10 @@ ParserResult Parser::parseStmtDefer() { return nullptr; Status |= Body; + if (bool(Body.get()->findAsyncNode())) { + DS->makeAsync(Context); + } + // Clone the current hasher and extract a Fingerprint. StableHasher currentHash{*CurrentTokenHash}; Fingerprint fp(std::move(currentHash)); diff --git a/lib/Sema/TypeCheckEffects.cpp b/lib/Sema/TypeCheckEffects.cpp index c23bd67e24db1..82e798fa00502 100644 --- a/lib/Sema/TypeCheckEffects.cpp +++ b/lib/Sema/TypeCheckEffects.cpp @@ -2854,9 +2854,6 @@ class Context { /// The guard expression controlling a catch. CatchGuard, - - /// A defer body - DeferBody, }; private: @@ -2973,6 +2970,21 @@ class Context { return isa(closure); } + bool isDeferBody() const { + if (!Function) + return false; + + if (ErrorHandlingIgnoresFunction) + return false; + + auto fnDecl = + dyn_cast_or_null(Function->getAbstractFunctionDecl()); + if (!fnDecl) + return false; + + return fnDecl->isDeferBody(); + } + static Context forTopLevelCode(TopLevelCodeDecl *D) { // Top-level code implicitly handles errors. return Context(/*handlesErrors=*/true, @@ -2999,10 +3011,6 @@ class Context { return Context(D->hasThrows(), D->isAsyncContext(), AnyFunctionRef(D), D); } - static Context forDeferBody(DeclContext *dc) { - return Context(Kind::DeferBody, dc); - } - static Context forInitializer(Initializer *init) { if (isa(init)) { return Context(Kind::DefaultArgument, init); @@ -3283,6 +3291,12 @@ class Context { return; } + if (isDeferBody()) { + Diags.diagnose(E.getStartLoc(), diag::throwing_op_in_defer_body, + getEffectSourceName(reason)); + return; + } + if (hasPolymorphicEffect(EffectKind::Throws)) { diagnoseThrowInLegalContext(Diags, E, isTryCovered, reason, diag::throwing_call_in_rethrows_function, @@ -3303,7 +3317,6 @@ class Context { case Kind::PropertyWrapper: case Kind::CatchPattern: case Kind::CatchGuard: - case Kind::DeferBody: Diags.diagnose(E.getStartLoc(), diag::throwing_op_in_illegal_context, static_cast(getKind()), getEffectSourceName(reason)); return; @@ -3324,6 +3337,11 @@ class Context { return; } + if (isDeferBody()) { + Diags.diagnose(S->getStartLoc(), diag::throw_in_defer_body); + return; + } + if (hasPolymorphicEffect(EffectKind::Throws)) { Diags.diagnose(S->getStartLoc(), diag::throw_in_rethrows_function); return; @@ -3340,7 +3358,6 @@ class Context { case Kind::PropertyWrapper: case Kind::CatchPattern: case Kind::CatchGuard: - case Kind::DeferBody: Diags.diagnose(S->getStartLoc(), diag::throw_in_illegal_context, static_cast(getKind())); return; @@ -3367,7 +3384,6 @@ class Context { case Kind::PropertyWrapper: case Kind::CatchPattern: case Kind::CatchGuard: - case Kind::DeferBody: assert(!DiagnoseErrorOnTry); // Diagnosed at the call sites. return; @@ -3448,6 +3464,21 @@ class Context { void diagnoseUnhandledAsyncSite(DiagnosticEngine &Diags, ASTNode node, std::optional maybeReason, bool forAwait = false) { + + // If this is an apply of a defer body, emit a special diagnostic. We must + // check this before we check `isImplicit` below since these are always + // implicit! + if (auto *applyExpr = dyn_cast_or_null(node.dyn_cast())) { + auto *calledDecl = applyExpr->getCalledValue(/*skipConversions=*/true); + if (auto *fnDecl = dyn_cast_or_null(calledDecl)) { + if (fnDecl->isDeferBody()) { + Diags.diagnose(fnDecl->getStartLoc(), + diag::async_defer_in_non_async_context); + return; + } + } + } + if (node.isImplicit()) { // The reason we return early on implicit nodes is that sometimes we // inject implicit closures, e.g. in 'async let' and we'd end up @@ -3477,7 +3508,6 @@ class Context { case Kind::PropertyWrapper: case Kind::CatchPattern: case Kind::CatchGuard: - case Kind::DeferBody: diagnoseAsyncInIllegalContext(Diags, node); return; } @@ -4639,7 +4669,6 @@ class CheckEffectsCoverage : public EffectsHandlingWalker ContextScope scope(*this, std::nullopt); scope.enterUnsafe(S->getDeferLoc()); - // Walk the call expression. We don't care about the rest. S->getCallExpr()->walk(*this); return ShouldNotRecurse; @@ -5004,9 +5033,7 @@ void TypeChecker::checkFunctionEffects(AbstractFunctionDecl *fn) { PrettyStackTraceDecl debugStack("checking effects handling for", fn); #endif - auto isDeferBody = isa(fn) && cast(fn)->isDeferBody(); - auto context = - isDeferBody ? Context::forDeferBody(fn) : Context::forFunction(fn); + auto context = Context::forFunction(fn); auto &ctx = fn->getASTContext(); CheckEffectsCoverage checker(ctx, context); diff --git a/test/expr/unary/async_await.swift b/test/expr/unary/async_await.swift index 98bc9ba5ffe8d..571cc4172fa0d 100644 --- a/test/expr/unary/async_await.swift +++ b/test/expr/unary/async_await.swift @@ -18,9 +18,6 @@ func getInt() async -> Int { return 5 } func test2( defaulted: Int = await getInt() // expected-error{{'async' call cannot occur in a default argument}} ) async { - defer { - _ = await getInt() // expected-error{{'async' call cannot occur in a defer body}} - } print("foo") } @@ -182,10 +179,10 @@ func testAsyncLet() async throws { } defer { - async let deferX: Int = await getInt() // expected-error {{'async let' cannot be used on declarations in a defer body}} - _ = await deferX // expected-error {{async let 'deferX' cannot be referenced in a defer body}} - async let _: Int = await getInt() // expected-error {{'async let' cannot be used on declarations in a defer body}} - async let _ = await getInt() // expected-error {{'async let' cannot be used on declarations in a defer body}} + async let deferX: Int = await getInt() + _ = await deferX + async let _: Int = await getInt() + async let _ = await getInt() } async let x1 = getIntUnsafely() // okay, try is implicit here diff --git a/test/stmt/defer.swift b/test/stmt/defer.swift index ac113a4b1baaa..45e08a26c1f67 100644 --- a/test/stmt/defer.swift +++ b/test/stmt/defer.swift @@ -161,3 +161,27 @@ func badForwardReference() { let z2 = 0 } + + +public func basicAsyncDefer() async { + defer { await asyncFunc() } + voidFunc() +} + +func asyncFunc() async {} +func voidFunc() {} + +func testClosure() async { + let f = { + defer { await asyncFunc() } + voidFunc() + } + + await f() +} + +func asyncDeferInSyncFunc() { + defer { await asyncFunc() } // expected-error {{'async' defer must appear within an 'async' context}} + voidFunc() +} +