diff --git a/include/swift/AST/AnyFunctionRef.h b/include/swift/AST/AnyFunctionRef.h index c287beb97dc42..ae31b67b5e10f 100644 --- a/include/swift/AST/AnyFunctionRef.h +++ b/include/swift/AST/AnyFunctionRef.h @@ -83,10 +83,6 @@ class AnyFunctionRef { TheFunction.get()->setCaptureInfo(captures); } - void getLocalCaptures(SmallVectorImpl &Result) const { - getCaptureInfo().getLocalCaptures(Result); - } - bool hasType() const { if (auto *AFD = TheFunction.dyn_cast()) return AFD->hasInterfaceType(); diff --git a/include/swift/AST/CaptureInfo.h b/include/swift/AST/CaptureInfo.h index 433b3d534221f..85559d6554ea7 100644 --- a/include/swift/AST/CaptureInfo.h +++ b/include/swift/AST/CaptureInfo.h @@ -197,15 +197,6 @@ class CaptureInfo { return StorageAndFlags.getPointer()->getCaptures(); } - /// Return a filtered list of the captures for this function, - /// filtering out global variables. This function returns the list that - /// actually needs to be closed over. - /// - void getLocalCaptures(SmallVectorImpl &Result) const; - - /// \returns true if getLocalCaptures() will return a non-empty list. - bool hasLocalCaptures() const; - /// \returns true if the function captures any generic type parameters. bool hasGenericParamCaptures() const { // FIXME: Ideally, everywhere that synthesizes a function should include diff --git a/include/swift/AST/Decl.h b/include/swift/AST/Decl.h index 738d111cc266b..8a0247fb06a49 100644 --- a/include/swift/AST/Decl.h +++ b/include/swift/AST/Decl.h @@ -8019,10 +8019,6 @@ class FuncDecl : public AbstractFunctionDecl { /// prior to type checking. bool isBinaryOperator() const; - void getLocalCaptures(SmallVectorImpl &Result) const { - return getCaptureInfo().getLocalCaptures(Result); - } - ParamDecl **getImplicitSelfDeclStorage(); /// Get the supertype method this method overrides, if any. diff --git a/lib/AST/CaptureInfo.cpp b/lib/AST/CaptureInfo.cpp index 1eb91771d4cfa..a107699ccb5a5 100644 --- a/lib/AST/CaptureInfo.cpp +++ b/lib/AST/CaptureInfo.cpp @@ -59,34 +59,7 @@ CaptureInfo CaptureInfo::empty() { return result; } -bool CaptureInfo::hasLocalCaptures() const { - for (auto capture : getCaptures()) { - if (capture.isLocalCapture()) - return true; - } - return false; -} - - -void CaptureInfo:: -getLocalCaptures(SmallVectorImpl &Result) const { - if (!hasLocalCaptures()) return; - - Result.reserve(getCaptures().size()); - - // Filter out global variables. - for (auto capture : getCaptures()) { - if (!capture.isLocalCapture()) - continue; - - Result.push_back(capture); - } -} - VarDecl *CaptureInfo::getIsolatedParamCapture() const { - if (!hasLocalCaptures()) - return nullptr; - for (const auto &capture : getCaptures()) { // NOTE: isLocalCapture() returns false if we have dynamic self metadata // since dynamic self metadata is never an isolated capture. So we can just diff --git a/lib/SIL/IR/TypeLowering.cpp b/lib/SIL/IR/TypeLowering.cpp index b3a2029738a8f..ba6397cb1a561 100644 --- a/lib/SIL/IR/TypeLowering.cpp +++ b/lib/SIL/IR/TypeLowering.cpp @@ -4139,11 +4139,21 @@ TypeConverter::getLoweredLocalCaptures(SILDeclRef fn) { DynamicSelfType *capturesDynamicSelf = nullptr; OpaqueValueExpr *capturesOpaqueValue = nullptr; - std::function collectCaptures; + std::function collectCaptures; std::function collectFunctionCaptures; std::function collectConstantCaptures; - collectCaptures = [&](CaptureInfo captureInfo, DeclContext *dc) { + auto recordCapture = [&](CapturedValue capture) { + ValueDecl *value = capture.getDecl(); + auto existing = captures.find(value); + if (existing != captures.end()) { + existing->second = existing->second.mergeFlags(capture); + } else { + captures.insert(std::pair(value, capture)); + } + }; + + collectCaptures = [&](CaptureInfo captureInfo) { assert(captureInfo.hasBeenComputed()); if (captureInfo.hasGenericParamCaptures()) @@ -4153,9 +4163,10 @@ TypeConverter::getLoweredLocalCaptures(SILDeclRef fn) { if (captureInfo.hasOpaqueValueCapture()) capturesOpaqueValue = captureInfo.getOpaqueValue(); - SmallVector localCaptures; - captureInfo.getLocalCaptures(localCaptures); - for (auto capture : localCaptures) { + for (auto capture : captureInfo.getCaptures()) { + if (!capture.isLocalCapture()) + continue; + // If the capture is of another local function, grab its transitive // captures instead. if (auto capturedFn = getAnyFunctionRefFromCapture(capture)) { @@ -4287,13 +4298,7 @@ TypeConverter::getLoweredLocalCaptures(SILDeclRef fn) { } // Collect non-function captures. - ValueDecl *value = capture.getDecl(); - auto existing = captures.find(value); - if (existing != captures.end()) { - existing->second = existing->second.mergeFlags(capture); - } else { - captures.insert(std::pair(value, capture)); - } + recordCapture(capture); } }; @@ -4305,8 +4310,21 @@ TypeConverter::getLoweredLocalCaptures(SILDeclRef fn) { return; PrettyStackTraceAnyFunctionRef("lowering local captures", curFn); - auto dc = curFn.getAsDeclContext(); - collectCaptures(curFn.getCaptureInfo(), dc); + collectCaptures(curFn.getCaptureInfo()); + + if (auto *afd = curFn.getAbstractFunctionDecl()) { + // If a local function inherits isolation from the enclosing context, + // make sure we capture the isolated parameter, if we haven't already. + if (afd->isLocalCapture()) { + auto actorIsolation = getActorIsolation(afd); + if (actorIsolation.getKind() == ActorIsolation::ActorInstance) { + if (auto *var = actorIsolation.getActorInstance()) { + assert(isa(var)); + recordCapture(CapturedValue(var, 0, afd->getLoc())); + } + } + } + } // A function's captures also include its default arguments, because // when we reference a function we don't track which default arguments @@ -4317,7 +4335,7 @@ TypeConverter::getLoweredLocalCaptures(SILDeclRef fn) { if (auto *AFD = curFn.getAbstractFunctionDecl()) { for (auto *P : *AFD->getParameters()) { if (P->hasDefaultExpr()) - collectCaptures(P->getDefaultArgumentCaptureInfo(), dc); + collectCaptures(P->getDefaultArgumentCaptureInfo()); } } }; @@ -4330,10 +4348,8 @@ TypeConverter::getLoweredLocalCaptures(SILDeclRef fn) { if (auto *afd = dyn_cast(curFn.getDecl())) { auto *param = getParameterAt(static_cast(afd), curFn.defaultArgIndex); - if (param->hasDefaultExpr()) { - auto dc = afd->getInnermostDeclContext(); - collectCaptures(param->getDefaultArgumentCaptureInfo(), dc); - } + if (param->hasDefaultExpr()) + collectCaptures(param->getDefaultArgumentCaptureInfo()); return; } diff --git a/lib/SILGen/SILGenConcurrency.cpp b/lib/SILGen/SILGenConcurrency.cpp index cf06eb943e309..c46093cfd231d 100644 --- a/lib/SILGen/SILGenConcurrency.cpp +++ b/lib/SILGen/SILGenConcurrency.cpp @@ -107,8 +107,8 @@ void SILGenFunction::emitExpectedExecutor() { // completely. if (F.isAsync() || (wantDataRaceChecks && funcDecl->isLocalCapture())) { - if (auto isolatedParam = funcDecl->getCaptureInfo() - .getIsolatedParamCapture()) { + auto loweredCaptures = SGM.Types.getLoweredLocalCaptures(SILDeclRef(funcDecl)); + if (auto isolatedParam = loweredCaptures.getIsolatedParamCapture()) { loadExpectedExecutorForLocalVar(isolatedParam); } else { auto loc = RegularLocation::getAutoGeneratedLocation(F.getLocation()); diff --git a/lib/SILGen/SILGenType.cpp b/lib/SILGen/SILGenType.cpp index 580ad3ad51e14..279e9a4835aba 100644 --- a/lib/SILGen/SILGenType.cpp +++ b/lib/SILGen/SILGenType.cpp @@ -262,10 +262,6 @@ class SILGenVTable : public SILVTableVisitor { void emitVTable() { PrettyStackTraceDecl("silgen emitVTable", theClass); - // Imported types don't have vtables right now. - if (theClass->hasClangNode()) - return; - // Populate our list of base methods and overrides. visitAncestor(theClass); @@ -317,6 +313,10 @@ class SILGenVTable : public SILVTableVisitor { } void visitAncestor(ClassDecl *ancestor) { + // Imported types don't have vtables right now. + if (ancestor->hasClangNode()) + return; + auto *superDecl = ancestor->getSuperclassDecl(); if (superDecl) visitAncestor(superDecl); @@ -1153,8 +1153,10 @@ class SILGenType : public TypeMemberVisitor { // Build a vtable if this is a class. if (auto theClass = dyn_cast(theType)) { - SILGenVTable genVTable(SGM, theClass); - genVTable.emitVTable(); + if (!theClass->hasClangNode()) { + SILGenVTable genVTable(SGM, theClass); + genVTable.emitVTable(); + } } // If this is a nominal type that is move only, emit a deinit table for it. diff --git a/lib/SILOptimizer/Differentiation/Common.cpp b/lib/SILOptimizer/Differentiation/Common.cpp index a1cd24887d27b..58980e9085f33 100644 --- a/lib/SILOptimizer/Differentiation/Common.cpp +++ b/lib/SILOptimizer/Differentiation/Common.cpp @@ -475,7 +475,6 @@ findMinimalDerivativeConfiguration(AbstractFunctionDecl *original, original->getInterfaceType()->castTo()); if (silParameterIndices->getCapacity() < parameterIndices->getCapacity()) { - assert(original->getCaptureInfo().hasLocalCaptures()); silParameterIndices = silParameterIndices->extendingCapacity(original->getASTContext(), parameterIndices->getCapacity()); diff --git a/lib/Sema/TypeCheckConcurrency.cpp b/lib/Sema/TypeCheckConcurrency.cpp index fc78a38abc141..ec96ca43715d2 100644 --- a/lib/Sema/TypeCheckConcurrency.cpp +++ b/lib/Sema/TypeCheckConcurrency.cpp @@ -2478,9 +2478,9 @@ namespace { /// Check closure captures for Sendable violations. void checkLocalCaptures(AnyFunctionRef localFunc) { - SmallVector captures; - localFunc.getCaptureInfo().getLocalCaptures(captures); - for (const auto &capture : captures) { + for (const auto &capture : localFunc.getCaptureInfo().getCaptures()) { + if (!capture.isLocalCapture()) + continue; if (capture.isDynamicSelfMetadata()) continue; if (capture.isOpaqueValue()) @@ -5174,9 +5174,7 @@ ActorIsolation ActorIsolationRequest::evaluate( llvm_unreachable("context cannot have erased isolation"); case ActorIsolation::ActorInstance: - if (auto param = func->getCaptureInfo().getIsolatedParamCapture()) - return inferredIsolation(enclosingIsolation); - break; + return inferredIsolation(enclosingIsolation); case ActorIsolation::GlobalActor: return inferredIsolation(enclosingIsolation); diff --git a/test/Concurrency/actor_isolation.swift b/test/Concurrency/actor_isolation.swift index 859c75a63693f..e1493c22197b5 100644 --- a/test/Concurrency/actor_isolation.swift +++ b/test/Concurrency/actor_isolation.swift @@ -755,6 +755,8 @@ func checkLocalFunctions() async { print(k) } +func callee(_: () -> ()) {} + @available(SwiftStdlib 5.1, *) actor LocalFunctionIsolatedActor { func a() -> Bool { // expected-note{{calls to instance method 'a()' from outside of its actor context are implicitly asynchronous}} @@ -774,6 +776,30 @@ actor LocalFunctionIsolatedActor { } return c() } + + func hasRecursiveLocalFunction() { + func recursiveLocalFunction(n: Int) { + _ = a() + callee { _ = a() } + if n > 0 { recursiveLocalFunction(n: n - 1) } + } + + recursiveLocalFunction(n: 10) + } + + func hasRecursiveLocalFunctions() { + recursiveLocalFunction() + + func recursiveLocalFunction() { + anotherRecursiveLocalFunction() + } + + func anotherRecursiveLocalFunction() { + callee { _ = a() } + _ = a() + } + } + } // ---------------------------------------------------------------------- diff --git a/test/SILGen/local_function_isolation.swift b/test/SILGen/local_function_isolation.swift new file mode 100644 index 0000000000000..87580d75d6e10 --- /dev/null +++ b/test/SILGen/local_function_isolation.swift @@ -0,0 +1,55 @@ +// RUN: %target-swift-frontend -emit-silgen %s -disable-availability-checking | %FileCheck %s + +// REQUIRES: concurrency + +class NotSendable {} + +func callee(_ ns: NotSendable) {} + +actor MyActor { + func isolatedToSelf(ns: NotSendable) { + // CHECK-LABEL: sil private [ossa] @$s24local_function_isolation7MyActorC14isolatedToSelf2nsyAA11NotSendableC_tF08implicitH7CaptureL_yyYaF : $@convention(thin) @async (@guaranteed NotSendable, @sil_isolated @guaranteed MyActor) -> () { + func implicitSelfCapture() async { + + // CHECK: [[COPY:%.*]] = copy_value %1 : $MyActor + // CHECK-NEXT: [[BORROW:%.*]] = begin_borrow [[COPY]] : $MyActor + // CHECK-NEXT: hop_to_executor [[BORROW]] : $MyActor + + // CHECK: [[FN:%.*]] = function_ref @$s24local_function_isolation4testyyYaF : $@convention(thin) @async () -> () + // CHECK-NEXT: apply [[FN]]() : $@convention(thin) @async () -> () + await test() + + // CHECK: hop_to_executor [[BORROW]] : $MyActor + // CHECK: [[FN:%.*]] = function_ref @$s24local_function_isolation6calleeyyAA11NotSendableCF : $@convention(thin) (@guaranteed NotSendable) -> () + // CHECK-NEXT: apply [[FN]](%0) : $@convention(thin) (@guaranteed NotSendable) -> () + + // we need to hop back to 'self' here + callee(ns) + + // CHECK: end_borrow [[BORROW]] : $MyActor + // CHECK-NEXT: destroy_value [[COPY]] : $MyActor + } + } +} + +func f(isolation: isolated MyActor, ns: NotSendable) { + // CHECK-LABEL: sil private [ossa] @$s24local_function_isolation1f0C02nsyAA7MyActorCYi_AA11NotSendableCtF23implicitIsolatedCaptureL_yyYaF : $@convention(thin) @async (@guaranteed NotSendable, @sil_isolated @guaranteed MyActor) -> () { + func implicitIsolatedCapture() async { + + // CHECK: [[COPY:%.*]] = copy_value %1 : $MyActor + // CHECK-NEXT: [[BORROW:%.*]] = begin_borrow [[COPY]] : $MyActor + // CHECK-NEXT: hop_to_executor [[BORROW]] : $MyActor + + // CHECK: [[FN:%.*]] = function_ref @$s24local_function_isolation4testyyYaF : $@convention(thin) @async () -> () + // CHECK-NEXT: apply [[FN]]() : $@convention(thin) @async () -> () + await test() + + // we need to hop back to 'isolation' here + callee(ns) + + // CHECK: end_borrow [[BORROW]] : $MyActor + // CHECK-NEXT: destroy_value [[COPY]] : $MyActor + } +} + +func test() async {}