diff --git a/lib/Sema/TypeCheckConcurrency.cpp b/lib/Sema/TypeCheckConcurrency.cpp index 0865bde24e1ca..fe29c375ea528 100644 --- a/lib/Sema/TypeCheckConcurrency.cpp +++ b/lib/Sema/TypeCheckConcurrency.cpp @@ -535,8 +535,14 @@ ActorIsolationRestriction ActorIsolationRestriction::forDeclaration( // Local captures can only be referenced in their local context or a // context that is guaranteed not to run concurrently with it. - if (cast(decl)->isLocalCapture()) + if (cast(decl)->isLocalCapture()) { + // Local functions are safe to capture; their bodies are checked based on + // where that capture is used. + if (isa(decl)) + return forUnrestricted(); + return forLocalCapture(decl->getDeclContext()); + } // Determine the actor isolation of the given declaration. switch (auto isolation = getActorIsolation(cast(decl))) { @@ -625,11 +631,53 @@ static bool isAsyncCall(const ApplyExpr *call) { Type funcTypeType = call->getFn()->getType(); if (!funcTypeType) return false; - FunctionType *funcType = funcTypeType->castTo(); + AnyFunctionType *funcType = funcTypeType->getAs(); + if (!funcType) + return false; return funcType->isAsync(); } +/// Determine whether we should diagnose data races within the current context. +/// +/// By default, we do this only in code that makes use of concurrency +/// features. +static bool shouldDiagnoseExistingDataRaces(const DeclContext *dc); + +/// Determine whether this closure is escaping. +static bool isEscapingClosure(const AbstractClosureExpr *closure) { + if (auto type = closure->getType()) { + if (auto fnType = type->getAs()) + return !fnType->isNoEscape(); + } + + return true; +} + namespace { + /// Check whether a particular context may execute concurrently within + /// another context. + class ConcurrentExecutionChecker { + /// Keeps track of the first location at which a given local function is + /// referenced from a context that may execute concurrently with the + /// context in which it was introduced. + llvm::SmallDenseMap concurrentRefs; + + public: + /// Determine whether (and where) a given local function is referenced + /// from a context that may execute concurrently with the context in + /// which it is declared. + /// + /// \returns the source location of the first reference to the local + /// function that may be concurrent. If the result is an invalid + /// source location, there are no such references. + SourceLoc getConcurrentReferenceLoc(const FuncDecl *localFunc); + + /// Determine whether code in the given use context might execute + /// concurrently with code in the definition context. + bool mayExecuteConcurrentlyWith( + const DeclContext *useContext, const DeclContext *defContext); + }; + /// Check for adherence to the actor isolation rules, emitting errors /// when actor-isolated declarations are used in an unsafe manner. class ActorIsolationChecker : public ASTWalker { @@ -637,10 +685,20 @@ namespace { SmallVector contextStack; SmallVector applyStack; + ConcurrentExecutionChecker concurrentExecutionChecker; + const DeclContext *getDeclContext() const { return contextStack.back(); } + /// Determine whether code in the given use context might execute + /// concurrently with code in the definition context. + bool mayExecuteConcurrentlyWith( + const DeclContext *useContext, const DeclContext *defContext) { + return concurrentExecutionChecker.mayExecuteConcurrentlyWith( + useContext, defContext); + } + public: ActorIsolationChecker(const DeclContext *dc) : ctx(dc->getASTContext()) { contextStack.push_back(dc); @@ -682,10 +740,19 @@ namespace { bool shouldWalkIntoTapExpression() override { return false; } - bool walkToDeclPre(Decl *D) override { - // Don't walk into functions; they'll be handled separately. - if (isa(D)) - return false; + bool walkToDeclPre(Decl *decl) override { + if (auto func = dyn_cast(decl)) { + contextStack.push_back(func); + } + + return true; + } + + bool walkToDeclPost(Decl *decl) override { + if (auto func = dyn_cast(decl)) { + assert(contextStack.back() == func); + contextStack.pop_back(); + } return true; } @@ -828,35 +895,6 @@ namespace { } } - /// Determine whether code in the given use context might execute - /// concurrently with code in the definition context. - bool mayExecuteConcurrentlyWith( - const DeclContext *useContext, const DeclContext *defContext) { - - // Walk the context chain from the use to the definition. - while (useContext != defContext) { - // If we find an escaping closure, it can be run concurrently. - if (auto closure = dyn_cast(useContext)) { - if (isEscapingClosure(closure)) - return true; - } - - // If we find a local function, it can escape and be run concurrently. - if (auto func = dyn_cast(useContext)) { - if (func->isLocalCapture()) - return true; - } - - // If we hit a module-scope context, it's not concurrent. - useContext = useContext->getParent(); - if (useContext->isModuleScopeContext()) - return false; - } - - // We hit the same context, so it won't execute concurrently. - return false; - } - // Retrieve the nearest enclosing actor context. static ClassDecl *getNearestEnclosingActorContext(const DeclContext *dc) { while (!dc->isModuleScopeContext()) { @@ -876,15 +914,16 @@ namespace { /// Diagnose a reference to an unsafe entity. /// /// \returns true if we diagnosed the entity, \c false otherwise. - bool diagnoseReferenceToUnsafe(ValueDecl *value, SourceLoc loc) { - // Only diagnose unsafe concurrent accesses within the context of an - // actor. This is globally unsafe, but locally enforceable. - if (!getNearestEnclosingActorContext(getDeclContext())) + bool diagnoseReferenceToUnsafeGlobal(ValueDecl *value, SourceLoc loc) { + if (!shouldDiagnoseExistingDataRaces(getDeclContext())) + return false; + + // Only diagnose direct references to mutable global state. + auto var = dyn_cast(value); + if (!var || var->isLet()) return false; - // Only diagnose direct references to mutable shared state. This is - // globally unsafe, but reduces the noise. - if (!isa(value) || !cast(value)->hasStorage()) + if (!var->getDeclContext()->isModuleScopeContext() && !var->isStatic()) return false; ctx.Diags.diagnose( @@ -1170,9 +1209,7 @@ namespace { value, loc, isolation.getGlobalActor()); case ActorIsolationRestriction::LocalCapture: - // Only diagnose unsafe concurrent accesses within the context of an - // actor. This is globally unsafe, but locally enforceable. - if (!getNearestEnclosingActorContext(getDeclContext())) + if (!shouldDiagnoseExistingDataRaces(getDeclContext())) return false; // Check whether we are in a context that will not execute concurrently @@ -1190,7 +1227,7 @@ namespace { return false; case ActorIsolationRestriction::Unsafe: - return diagnoseReferenceToUnsafe(value, loc); + return diagnoseReferenceToUnsafeGlobal(value, loc); } llvm_unreachable("unhandled actor isolation kind!"); } @@ -1303,16 +1340,6 @@ namespace { llvm_unreachable("unhandled actor isolation kind!"); } - /// Determine whether this closure is escaping. - static bool isEscapingClosure(const AbstractClosureExpr *closure) { - if (auto type = closure->getType()) { - if (auto fnType = type->getAs()) - return !fnType->isNoEscape(); - } - - return true; - } - /// Determine the isolation of a particular closure. /// /// This function assumes that enclosing closures have already had their @@ -1410,6 +1437,138 @@ namespace { }; } +SourceLoc ConcurrentExecutionChecker::getConcurrentReferenceLoc( + const FuncDecl *localFunc) { + + // If we've already computed a result, we're done. + auto known = concurrentRefs.find(localFunc); + if (known != concurrentRefs.end()) + return known->second; + + // Record that there are no concurrent references to this local function. This + // prevents infinite recursion if two local functions call each other. + concurrentRefs[localFunc] = SourceLoc(); + + class ConcurrentLocalRefWalker : public ASTWalker { + ConcurrentExecutionChecker &checker; + const FuncDecl *targetFunc; + SmallVector contextStack; + + const DeclContext *getDeclContext() const { + return contextStack.back(); + } + + public: + ConcurrentLocalRefWalker( + ConcurrentExecutionChecker &checker, const FuncDecl *targetFunc + ) : checker(checker), targetFunc(targetFunc) { + contextStack.push_back(targetFunc->getDeclContext()); + } + + std::pair walkToExprPre(Expr *expr) override { + if (auto *closure = dyn_cast(expr)) { + contextStack.push_back(closure); + return { true, expr }; + } + + if (auto *declRef = dyn_cast(expr)) { + // If this is a reference to the target function from a context + // that may execute concurrently with the context where the target + // function was declared, record the location. + if (declRef->getDecl() == targetFunc && + checker.mayExecuteConcurrentlyWith( + getDeclContext(), contextStack.front())) { + SourceLoc &loc = checker.concurrentRefs[targetFunc]; + if (loc.isInvalid()) + loc = declRef->getLoc(); + + return { false, expr }; + } + + return { true, expr }; + } + + return { true, expr }; + } + + Expr *walkToExprPost(Expr *expr) override { + if (auto *closure = dyn_cast(expr)) { + assert(contextStack.back() == closure); + contextStack.pop_back(); + } + + return expr; + } + + bool walkToDeclPre(Decl *decl) override { + if (isa(decl) || isa(decl)) + return false; + + if (auto func = dyn_cast(decl)) { + contextStack.push_back(func); + } + + return true; + } + + bool walkToDeclPost(Decl *decl) override { + if (auto func = dyn_cast(decl)) { + assert(contextStack.back() == func); + contextStack.pop_back(); + } + + return true; + } + }; + + // Walk the body of the enclosing function, where all references to the + // given local function would occur. + Stmt *enclosingBody = nullptr; + DeclContext *enclosingDC = localFunc->getDeclContext(); + if (auto enclosingFunc = dyn_cast(enclosingDC)) + enclosingBody = enclosingFunc->getBody(); + else if (auto enclosingClosure = dyn_cast(enclosingDC)) + enclosingBody = enclosingClosure->getBody(); + + assert(enclosingBody && "Cannot have a local function here"); + ConcurrentLocalRefWalker walker(*this, localFunc); + enclosingBody->walk(walker); + + return concurrentRefs[localFunc]; +} + +bool ConcurrentExecutionChecker::mayExecuteConcurrentlyWith( + const DeclContext *useContext, const DeclContext *defContext) { + // Walk the context chain from the use to the definition. + while (useContext != defContext) { + // If we find an escaping closure, it can be run concurrently. + if (auto closure = dyn_cast(useContext)) { + if (isEscapingClosure(closure)) + return true; + } + + // If we find a local function that was referenced in code that can be + // executed concurrently with where the local function was declared, the + // local function can be run concurrently. + if (auto func = dyn_cast(useContext)) { + if (func->isLocalCapture()) { + SourceLoc concurrentLoc = getConcurrentReferenceLoc(func); + if (concurrentLoc.isValid()) + return true; + } + } + + // If we hit a module-scope or type context context, it's not + // concurrent. + useContext = useContext->getParent(); + if (useContext->isModuleScopeContext() || useContext->isTypeContext()) + return false; + } + + // We hit the same context, so it won't execute concurrently. + return false; +} + void swift::checkTopLevelActorIsolation(TopLevelCodeDecl *decl) { ActorIsolationChecker checker(decl); decl->getBody()->walk(checker); @@ -1447,7 +1606,8 @@ void swift::checkPropertyWrapperActorIsolation( /// \returns the actor isolation determined from attributes alone (with no /// inference rules). Returns \c None if there were no attributes on this /// declaration. -static Optional getIsolationFromAttributes(Decl *decl) { +static Optional getIsolationFromAttributes( + const Decl *decl, bool shouldDiagnose = true) { // Look up attributes on the declaration that can affect its actor isolation. // If any of them are present, use that attribute. auto independentAttr = decl->getAttrs().getAttribute(); @@ -1467,12 +1627,14 @@ static Optional getIsolationFromAttributes(Decl *decl) { name = selfTypeDecl->getName(); } - decl->diagnose( - diag::actor_isolation_multiple_attr, decl->getDescriptiveKind(), - name, independentAttr->getAttrName(), - globalActorAttr->second->getName().str()) - .highlight(independentAttr->getRangeWithAt()) - .highlight(globalActorAttr->first->getRangeWithAt()); + if (shouldDiagnose) { + decl->diagnose( + diag::actor_isolation_multiple_attr, decl->getDescriptiveKind(), + name, independentAttr->getAttrName(), + globalActorAttr->second->getName().str()) + .highlight(independentAttr->getRangeWithAt()) + .highlight(globalActorAttr->first->getRangeWithAt()); + } } // If the declaration is explicitly marked @actorIndependent, report it as @@ -1751,3 +1913,40 @@ void swift::checkOverrideActorIsolation(ValueDecl *value) { value->getDescriptiveKind(), value->getName(), overriddenIsolation); overridden->diagnose(diag::overridden_here); } + +static bool shouldDiagnoseExistingDataRaces(const DeclContext *dc) { + while (!dc->isModuleScopeContext()) { + if (auto closure = dyn_cast(dc)) { + // Async closures use concurrency features. + if (closure->getType() && closure->isBodyAsync()) + return true; + } else if (auto decl = dc->getAsDecl()) { + // If any isolation attributes are present, we're using concurrency + // features. + if (getIsolationFromAttributes(decl, /*shouldDiagnose=*/false)) + return true; + + if (auto func = dyn_cast(decl)) { + // Async functions use concurrency features. + if (func->hasAsync()) + return true; + + // If there is an explicit @asyncHandler, we're using concurrency + // features. + if (func->getAttrs().hasAttribute()) + return true; + } + } + + // If we're in an actor, we're using concurrency features. + if (auto classDecl = dc->getSelfClassDecl()) { + if (classDecl->isActor()) + return true; + } + + // Keep looking. + dc = dc->getParent(); + } + + return false; +} diff --git a/lib/Sema/TypeCheckEffects.cpp b/lib/Sema/TypeCheckEffects.cpp index bdb83d57dfe8c..f84a989b27c66 100644 --- a/lib/Sema/TypeCheckEffects.cpp +++ b/lib/Sema/TypeCheckEffects.cpp @@ -2015,6 +2015,20 @@ class CheckEffectsCoverage : public EffectsHandlingWalker } }; +// Find nested functions and perform effects checking on them. +struct LocalFunctionEffectsChecker : ASTWalker { + bool walkToDeclPre(Decl *D) override { + if (auto func = dyn_cast(D)) { + if (func->getDeclContext()->isLocalContext()) + TypeChecker::checkFunctionEffects(func); + + return false; + } + + return true; + } +}; + } // end anonymous namespace void TypeChecker::checkTopLevelEffects(TopLevelCodeDecl *code) { @@ -2026,6 +2040,7 @@ void TypeChecker::checkTopLevelEffects(TopLevelCodeDecl *code) { checker.setTopLevelThrowWithoutTry(); code->getBody()->walk(checker); + code->getBody()->walk(LocalFunctionEffectsChecker()); } void TypeChecker::checkFunctionEffects(AbstractFunctionDecl *fn) { @@ -2045,7 +2060,9 @@ void TypeChecker::checkFunctionEffects(AbstractFunctionDecl *fn) { if (auto body = fn->getBody()) { body->walk(checker); + body->walk(LocalFunctionEffectsChecker()); } + if (auto ctor = dyn_cast(fn)) if (auto superInit = ctor->getSuperInitCall()) superInit->walk(checker); @@ -2056,6 +2073,7 @@ void TypeChecker::checkInitializerEffects(Initializer *initCtx, auto &ctx = initCtx->getASTContext(); CheckEffectsCoverage checker(ctx, Context::forInitializer(initCtx)); init->walk(checker); + init->walk(LocalFunctionEffectsChecker()); } /// Check the correctness of effects within the given enum @@ -2070,6 +2088,7 @@ void TypeChecker::checkEnumElementEffects(EnumElementDecl *elt, Expr *E) { auto &ctx = elt->getASTContext(); CheckEffectsCoverage checker(ctx, Context::forEnumElementInitializer(elt)); E->walk(checker); + E->walk(LocalFunctionEffectsChecker()); } void TypeChecker::checkPropertyWrapperEffects( @@ -2077,6 +2096,7 @@ void TypeChecker::checkPropertyWrapperEffects( auto &ctx = binding->getASTContext(); CheckEffectsCoverage checker(ctx, Context::forPatternBinding(binding)); expr->walk(checker); + expr->walk(LocalFunctionEffectsChecker()); } bool TypeChecker::canThrow(Expr *expr) { diff --git a/lib/Sema/TypeCheckStmt.cpp b/lib/Sema/TypeCheckStmt.cpp index ee22cad4d2e6e..d1294051b10d6 100644 --- a/lib/Sema/TypeCheckStmt.cpp +++ b/lib/Sema/TypeCheckStmt.cpp @@ -2068,8 +2068,10 @@ TypeCheckFunctionBodyRequest::evaluate(Evaluator &evaluator, performAbstractFuncDeclDiagnostics(AFD); TypeChecker::computeCaptures(AFD); - checkFunctionActorIsolation(AFD); - TypeChecker::checkFunctionEffects(AFD); + if (!AFD->getDeclContext()->isLocalContext()) { + checkFunctionActorIsolation(AFD); + TypeChecker::checkFunctionEffects(AFD); + } return hadError ? errorBody() : body; } diff --git a/test/Concurrency/actor_isolation.swift b/test/Concurrency/actor_isolation.swift index 25fc41059f01f..3571f702a6519 100644 --- a/test/Concurrency/actor_isolation.swift +++ b/test/Concurrency/actor_isolation.swift @@ -2,7 +2,7 @@ // REQUIRES: concurrency let immutableGlobal: String = "hello" -var mutableGlobal: String = "can't touch this" // expected-note 2{{var declared here}} +var mutableGlobal: String = "can't touch this" // expected-note 3{{var declared here}} func globalFunc() { } func acceptClosure(_: () -> T) { } @@ -162,6 +162,11 @@ extension MyActor { } } + acceptEscapingClosure { + localFn1() + localFn2() + } + localVar = 0 // Partial application @@ -303,4 +308,62 @@ func testGlobalRestrictions(actor: MyActor) async { // Operations on non-instances are permitted. MyActor.synchronousStatic() MyActor.synchronousClass() + + // Global mutable state cannot be accessed. + _ = mutableGlobal // expected-warning{{reference to var 'mutableGlobal' is not concurrency-safe because it involves shared mutable state}} + + // Local mutable variables cannot be accessed from concurrently-executing + // code. + var i = 17 // expected-note{{var declared here}} + acceptEscapingClosure { + i = 42 // expected-warning{{local var 'i' is unsafe to reference in code that may execute concurrently}} + } + print(i) +} + +// ---------------------------------------------------------------------- +// Local function isolation restrictions +// ---------------------------------------------------------------------- +func checkLocalFunctions() async { + var i = 0 + var j = 0 // expected-note{{var declared here}} + + func local1() { + i = 17 + } + + func local2() { + j = 42 // expected-warning{{local var 'j' is unsafe to reference in code that may execute concurrently}} + } + + // Okay to call locally. + local1() + local2() + + // Non-concurrent closures don't cause problems. + acceptClosure { + local1() + local2() + } + + // Escaping closures can make the local function execute concurrently. + acceptEscapingClosure { + local2() + } + + print(i) + print(j) + + var k = 17 // expected-note{{var declared here}} + func local4() { + acceptEscapingClosure { + local3() + } + } + + func local3() { + k = 25 // expected-warning{{local var 'k' is unsafe to reference in code that may execute concurrently}} + } + + print(k) } diff --git a/test/Concurrency/async_task_groups.swift b/test/Concurrency/async_task_groups.swift index 1eef802ee2701..d9b83f3c972d4 100644 --- a/test/Concurrency/async_task_groups.swift +++ b/test/Concurrency/async_task_groups.swift @@ -196,7 +196,7 @@ extension Collection { var submitted = 0 func submitNext() async throws { - await group.add { + await group.add { [submitted,i] in let value = await try transform(self[i]) return (submitted, value) }