From c5c33d6c5a278cb659a1a3f051b7e00f2ab0e458 Mon Sep 17 00:00:00 2001 From: Egor Zhdan Date: Fri, 22 Aug 2025 18:18:14 +0100 Subject: [PATCH] [cxx-interop] Allow retain/release operations to be methods Some foreign reference types such as IUnknown define retain/release operations as methods of the type. Previously Swift only supported retain/release operations as standalone functions. The syntax for member functions would be `SWIFT_SHARED_REFERENCE(.doRetain, .doRelease)`. rdar://160696723 (cherry picked from commit e78ce6165f51078a233f4da17eab87c3a74f0cac) --- .../swift/AST/DiagnosticsClangImporter.def | 3 + lib/ClangImporter/ClangImporter.cpp | 25 +++- lib/ClangImporter/ImportDecl.cpp | 46 +++++-- lib/ClangImporter/ImporterImpl.h | 2 +- lib/ClangImporter/SwiftDeclSynthesizer.cpp | 3 +- lib/IRGen/GenObjC.cpp | 5 + .../Inputs/lifetime-operation-methods.h | 121 ++++++++++++++++++ .../foreign-reference/Inputs/module.modulemap | 5 + ...e-operation-methods-module-interface.swift | 49 +++++++ ...fetime-operation-methods-typechecker.swift | 6 + .../lifetime-operation-methods.swift | 76 +++++++++++ 11 files changed, 325 insertions(+), 16 deletions(-) create mode 100644 test/Interop/Cxx/foreign-reference/Inputs/lifetime-operation-methods.h create mode 100644 test/Interop/Cxx/foreign-reference/lifetime-operation-methods-module-interface.swift create mode 100644 test/Interop/Cxx/foreign-reference/lifetime-operation-methods-typechecker.swift create mode 100644 test/Interop/Cxx/foreign-reference/lifetime-operation-methods.swift diff --git a/include/swift/AST/DiagnosticsClangImporter.def b/include/swift/AST/DiagnosticsClangImporter.def index 6befce05c280a..0bfb808c03dd0 100644 --- a/include/swift/AST/DiagnosticsClangImporter.def +++ b/include/swift/AST/DiagnosticsClangImporter.def @@ -265,6 +265,9 @@ ERROR(foreign_reference_types_release_non_void_return_type, none, ERROR(foreign_reference_types_retain_release_not_a_function_decl, none, "specified %select{retain|release}0 function '%1' is not a function", (bool, StringRef)) +ERROR(foreign_reference_types_retain_release_not_an_instance_function, none, + "specified %select{retain|release}0 function '%1' is a static function; expected an instance function", + (bool, StringRef)) ERROR(conforms_to_missing_dot, none, "expected module name and protocol name separated by '.' in protocol " "conformance; '%0' is invalid", diff --git a/lib/ClangImporter/ClangImporter.cpp b/lib/ClangImporter/ClangImporter.cpp index ae91fe04fb822..47361b38a0f0f 100644 --- a/lib/ClangImporter/ClangImporter.cpp +++ b/lib/ClangImporter/ClangImporter.cpp @@ -7876,10 +7876,27 @@ getRefParentDecls(const clang::RecordDecl *decl, ASTContext &ctx, } llvm::SmallVector -importer::getValueDeclsForName( - const clang::Decl *decl, ASTContext &ctx, StringRef name) { +importer::getValueDeclsForName(NominalTypeDecl *decl, StringRef name) { + // If the name is empty, don't try to find any decls. + if (name.empty()) + return {}; + + auto &ctx = decl->getASTContext(); + auto clangDecl = decl->getClangDecl(); llvm::SmallVector results; - auto *clangMod = decl->getOwningModule(); + + if (name.starts_with(".")) { + // Look for a member of decl instead of a global. + StringRef memberName = name.drop_front(1); + if (memberName.empty()) + return {}; + auto declName = DeclName(ctx.getIdentifier(memberName)); + auto allResults = evaluateOrDefault( + ctx.evaluator, ClangRecordMemberLookup({decl, declName}), {}); + return SmallVector(allResults.begin(), allResults.end()); + } + + auto *clangMod = clangDecl->getOwningModule(); if (clangMod && clangMod->isSubModule()) clangMod = clangMod->getTopLevelModule(); if (clangMod) { @@ -8487,7 +8504,7 @@ CustomRefCountingOperationResult CustomRefCountingOperation::evaluate( return {CustomRefCountingOperationResult::immortal, nullptr, name}; llvm::SmallVector results = - getValueDeclsForName(swiftDecl->getClangDecl(), ctx, name); + getValueDeclsForName(const_cast(swiftDecl), name); if (results.size() == 1) return {CustomRefCountingOperationResult::foundOperation, results.front(), name}; diff --git a/lib/ClangImporter/ImportDecl.cpp b/lib/ClangImporter/ImportDecl.cpp index 8548ef054c5bd..b648c1fafccd4 100644 --- a/lib/ClangImporter/ImportDecl.cpp +++ b/lib/ClangImporter/ImportDecl.cpp @@ -2747,6 +2747,7 @@ namespace { enum class RetainReleaseOperationKind { notAfunction, + notAnInstanceFunction, invalidReturnType, invalidParameters, valid @@ -2760,17 +2761,32 @@ namespace { if (!operationFn) return RetainReleaseOperationKind::notAfunction; - if (operationFn->getParameters()->size() != 1) - return RetainReleaseOperationKind::invalidParameters; + if (operationFn->isStatic()) + return RetainReleaseOperationKind::notAnInstanceFunction; - Type paramType = - operationFn->getParameters()->get(0)->getInterfaceType(); - // Unwrap if paramType is an OptionalType - if (Type optionalType = paramType->getOptionalObjectType()) { - paramType = optionalType; - } + if (operationFn->isInstanceMember()) { + if (operationFn->getParameters()->size() != 0) + return RetainReleaseOperationKind::invalidParameters; + } else { + if (operationFn->getParameters()->size() != 1) + return RetainReleaseOperationKind::invalidParameters; + } + + Type paramType; + NominalTypeDecl *paramDecl = nullptr; + if (!operationFn->isInstanceMember()) { + paramType = + operationFn->getParameters()->get(0)->getInterfaceType(); + // Unwrap if paramType is an OptionalType + if (Type optionalType = paramType->getOptionalObjectType()) { + paramType = optionalType; + } - swift::NominalTypeDecl *paramDecl = paramType->getAnyNominal(); + paramDecl = paramType->getAnyNominal(); + } else { + paramDecl = cast(operationFn->getParent()); + paramType = paramDecl->getDeclaredInterfaceType(); + } // The return type should be void (for release functions), or void // or the parameter type (for retain functions). @@ -2855,6 +2871,12 @@ namespace { diag::foreign_reference_types_retain_release_not_a_function_decl, false, retainOperation.name); break; + case RetainReleaseOperationKind::notAnInstanceFunction: + Impl.diagnose( + loc, + diag::foreign_reference_types_retain_release_not_an_instance_function, + false, retainOperation.name); + break; case RetainReleaseOperationKind::invalidReturnType: Impl.diagnose( loc, @@ -2920,6 +2942,12 @@ namespace { diag::foreign_reference_types_retain_release_not_a_function_decl, true, releaseOperation.name); break; + case RetainReleaseOperationKind::notAnInstanceFunction: + Impl.diagnose( + loc, + diag::foreign_reference_types_retain_release_not_an_instance_function, + true, releaseOperation.name); + break; case RetainReleaseOperationKind::invalidReturnType: Impl.diagnose( loc, diff --git a/lib/ClangImporter/ImporterImpl.h b/lib/ClangImporter/ImporterImpl.h index 1171830f753f1..1e7944e13c971 100644 --- a/lib/ClangImporter/ImporterImpl.h +++ b/lib/ClangImporter/ImporterImpl.h @@ -2151,7 +2151,7 @@ ImportedType findOptionSetEnum(clang::QualType type, /// /// The name we're looking for is the Swift name. llvm::SmallVector -getValueDeclsForName(const clang::Decl *decl, ASTContext &ctx, StringRef name); +getValueDeclsForName(NominalTypeDecl* decl, StringRef name); } // end namespace importer } // end namespace swift diff --git a/lib/ClangImporter/SwiftDeclSynthesizer.cpp b/lib/ClangImporter/SwiftDeclSynthesizer.cpp index 308c2ab121f0f..551aca1d5c944 100644 --- a/lib/ClangImporter/SwiftDeclSynthesizer.cpp +++ b/lib/ClangImporter/SwiftDeclSynthesizer.cpp @@ -2767,8 +2767,7 @@ FuncDecl *SwiftDeclSynthesizer::findExplicitDestroy( if (!destroyFuncName.consume_front("destroy:")) continue; - auto decls = getValueDeclsForName( - clangType, nominal->getASTContext(), destroyFuncName); + auto decls = getValueDeclsForName(nominal, destroyFuncName); for (auto decl : decls) { auto func = dyn_cast(decl); if (!func) diff --git a/lib/IRGen/GenObjC.cpp b/lib/IRGen/GenObjC.cpp index 0dd45be382692..6b27a6e7b855e 100644 --- a/lib/IRGen/GenObjC.cpp +++ b/lib/IRGen/GenObjC.cpp @@ -1722,6 +1722,11 @@ void IRGenFunction::emitBlockRelease(llvm::Value *value) { void IRGenFunction::emitForeignReferenceTypeLifetimeOperation( ValueDecl *fn, llvm::Value *value, bool needsNullCheck) { + if (auto originalDecl = fn->getASTContext() + .getClangModuleLoader() + ->getOriginalForClonedMember(fn)) + fn = originalDecl; + assert(fn->getClangDecl() && isa(fn->getClangDecl())); auto clangFn = cast(fn->getClangDecl()); diff --git a/test/Interop/Cxx/foreign-reference/Inputs/lifetime-operation-methods.h b/test/Interop/Cxx/foreign-reference/Inputs/lifetime-operation-methods.h new file mode 100644 index 0000000000000..b101f0998e859 --- /dev/null +++ b/test/Interop/Cxx/foreign-reference/Inputs/lifetime-operation-methods.h @@ -0,0 +1,121 @@ +#include + +struct RefCountedBox { + int value; + int refCount = 1; + + RefCountedBox(int value) : value(value) {} + + void doRetain() { refCount++; } + void doRelease() { refCount--; } +} SWIFT_SHARED_REFERENCE(.doRetain, .doRelease); + +struct DerivedRefCountedBox : RefCountedBox { + int secondValue = 1; + DerivedRefCountedBox(int value, int secondValue) + : RefCountedBox(value), secondValue(secondValue) {} +}; + +// MARK: Retain in a base type, release in derived + +struct BaseHasRetain { + mutable int refCount = 1; + void doRetainInBase() const { refCount++; } +}; + +struct DerivedHasRelease : BaseHasRetain { + int value; + DerivedHasRelease(int value) : value(value) {} + + void doRelease() const { refCount--; } +} SWIFT_SHARED_REFERENCE(.doRetainInBase, .doRelease); + +// MARK: Retain in a base type, release in templated derived + +template +struct TemplatedDerivedHasRelease : BaseHasRetain { + T value; + TemplatedDerivedHasRelease(T value) : value(value) {} + + void doReleaseTemplated() const { refCount--; } +} SWIFT_SHARED_REFERENCE(.doRetainInBase, .doReleaseTemplated); + +using TemplatedDerivedHasReleaseFloat = TemplatedDerivedHasRelease; +using TemplatedDerivedHasReleaseInt = TemplatedDerivedHasRelease; + +// MARK: Retain/release in CRTP base type + +template +struct CRTPBase { + mutable int refCount = 1; + void crtpRetain() const { refCount++; } + void crtpRelease() const { refCount--; } +} SWIFT_SHARED_REFERENCE(.crtpRetain, .crtpRelease); + +struct CRTPDerived : CRTPBase { + int value; + CRTPDerived(int value) : value(value) {} +}; + +// MARK: Virtual retain and release + +struct VirtualRetainRelease { + int value; + mutable int refCount = 1; + VirtualRetainRelease(int value) : value(value) {} + + virtual void doRetainVirtual() const { refCount++; } + virtual void doReleaseVirtual() const { refCount--; } + virtual ~VirtualRetainRelease() = default; +} SWIFT_SHARED_REFERENCE(.doRetainVirtual, .doReleaseVirtual); + +struct DerivedVirtualRetainRelease : VirtualRetainRelease { + DerivedVirtualRetainRelease(int value) : VirtualRetainRelease(value) {} + + mutable bool calledDerived = false; + void doRetainVirtual() const override { refCount++; calledDerived = true; } + void doReleaseVirtual() const override { refCount--; } +}; + +// MARK: Pure virtual retain and release + +struct PureVirtualRetainRelease { + int value; + mutable int refCount = 1; + PureVirtualRetainRelease(int value) : value(value) {} + + virtual void doRetainPure() const = 0; + virtual void doReleasePure() const = 0; + virtual ~PureVirtualRetainRelease() = default; +} SWIFT_SHARED_REFERENCE(.doRetainPure, .doReleasePure); + +struct DerivedPureVirtualRetainRelease : PureVirtualRetainRelease { + mutable int refCount = 1; + + DerivedPureVirtualRetainRelease(int value) : PureVirtualRetainRelease(value) {} + void doRetainPure() const override { refCount++; } + void doReleasePure() const override { refCount--; } +}; + +// MARK: Static retain/release +#ifdef INCORRECT +struct StaticRetainRelease { +// expected-error@-1 {{specified retain function '.staticRetain' is a static function; expected an instance function}} +// expected-error@-2 {{specified release function '.staticRelease' is a static function; expected an instance function}} + int value; + int refCount = 1; + + StaticRetainRelease(int value) : value(value) {} + + static void staticRetain(StaticRetainRelease* o) { o->refCount++; } + static void staticRelease(StaticRetainRelease* o) { o->refCount--; } +} SWIFT_SHARED_REFERENCE(.staticRetain, .staticRelease); + +struct DerivedStaticRetainRelease : StaticRetainRelease { +// expected-error@-1 {{cannot find retain function '.staticRetain' for reference type 'DerivedStaticRetainRelease'}} +// expected-error@-2 {{cannot find release function '.staticRelease' for reference type 'DerivedStaticRetainRelease'}} + int secondValue = 1; + DerivedStaticRetainRelease(int value, int secondValue) + : StaticRetainRelease(value), secondValue(secondValue) {} +}; +#endif diff --git a/test/Interop/Cxx/foreign-reference/Inputs/module.modulemap b/test/Interop/Cxx/foreign-reference/Inputs/module.modulemap index 40edefa638bd8..6c903dd6b118f 100644 --- a/test/Interop/Cxx/foreign-reference/Inputs/module.modulemap +++ b/test/Interop/Cxx/foreign-reference/Inputs/module.modulemap @@ -44,6 +44,11 @@ module ReferenceCountedObjCProperty { export * } +module LifetimeOperationMethods { + header "lifetime-operation-methods.h" + requires cplusplus +} + module MemberLayout { header "member-layout.h" requires cplusplus diff --git a/test/Interop/Cxx/foreign-reference/lifetime-operation-methods-module-interface.swift b/test/Interop/Cxx/foreign-reference/lifetime-operation-methods-module-interface.swift new file mode 100644 index 0000000000000..73ad9a291ce11 --- /dev/null +++ b/test/Interop/Cxx/foreign-reference/lifetime-operation-methods-module-interface.swift @@ -0,0 +1,49 @@ +// RUN: %target-swift-ide-test -print-module -cxx-interoperability-mode=upcoming-swift -I %swift_src_root/lib/ClangImporter/SwiftBridging -module-to-print=LifetimeOperationMethods -I %S/Inputs -source-filename=x | %FileCheck %s + +// CHECK: class RefCountedBox { +// CHECK: func doRetain() +// CHECK: func doRelease() +// CHECK: } +// CHECK: class DerivedRefCountedBox { +// CHECK: func doRetain() +// CHECK: func doRelease() +// CHECK: } + +// CHECK: class DerivedHasRelease { +// CHECK: func doRelease() +// CHECK: func doRetainInBase() +// CHECK: } + +// CHECK: class TemplatedDerivedHasRelease { +// CHECK: var value: Float +// CHECK: func doReleaseTemplated() +// CHECK: func doRetainInBase() +// CHECK: } +// CHECK: class TemplatedDerivedHasRelease { +// CHECK: var value: Int32 +// CHECK: func doReleaseTemplated() +// CHECK: func doRetainInBase() +// CHECK: } + +// CHECK: class CRTPDerived { +// CHECK: var value: Int32 +// CHECK: } + +// CHECK: class VirtualRetainRelease { +// CHECK: func doRetainVirtual() +// CHECK: func doReleaseVirtual() +// CHECK: } +// CHECK: class DerivedVirtualRetainRelease { +// CHECK: func doRetainVirtual() +// CHECK: func doReleaseVirtual() +// CHECK: } + +// CHECK: class PureVirtualRetainRelease { +// CHECK: func doRetainPure() +// CHECK: func doReleasePure() +// CHECK: } +// CHECK: class DerivedPureVirtualRetainRelease { +// CHECK: func doRetainPure() +// CHECK: func doReleasePure() +// CHECK: var refCount: Int32 +// CHECK: } diff --git a/test/Interop/Cxx/foreign-reference/lifetime-operation-methods-typechecker.swift b/test/Interop/Cxx/foreign-reference/lifetime-operation-methods-typechecker.swift new file mode 100644 index 0000000000000..9241a939159da --- /dev/null +++ b/test/Interop/Cxx/foreign-reference/lifetime-operation-methods-typechecker.swift @@ -0,0 +1,6 @@ +// RUN: %target-typecheck-verify-swift -Xcc -DINCORRECT -I %S%{fs-sep}Inputs -I %swift_src_root/lib/ClangImporter/SwiftBridging -verify-additional-file %S%{fs-sep}Inputs%{fs-sep}lifetime-operation-methods.h -cxx-interoperability-mode=upcoming-swift -disable-availability-checking + +import LifetimeOperationMethods + +let _ = StaticRetainRelease(123) +let _ = DerivedStaticRetainRelease(123, 456) diff --git a/test/Interop/Cxx/foreign-reference/lifetime-operation-methods.swift b/test/Interop/Cxx/foreign-reference/lifetime-operation-methods.swift new file mode 100644 index 0000000000000..64f7c700e5a54 --- /dev/null +++ b/test/Interop/Cxx/foreign-reference/lifetime-operation-methods.swift @@ -0,0 +1,76 @@ +// RUN: %target-run-simple-swift(-I %S/Inputs -cxx-interoperability-mode=upcoming-swift -I %swift_src_root/lib/ClangImporter/SwiftBridging -Xfrontend -disable-availability-checking) + +// Temporarily disable when running with an older runtime (rdar://128681137) +// UNSUPPORTED: use_os_stdlib +// UNSUPPORTED: back_deployment_runtime + +import StdlibUnittest +import LifetimeOperationMethods + +var LifetimeMethodsTestSuite = TestSuite("Lifetime operations that are instance methods") + +LifetimeMethodsTestSuite.test("retain/release methods") { + let a = RefCountedBox(123) + expectEqual(a.value, 123) + expectTrue(a.refCount > 0) + expectTrue(a.refCount < 10) // optimizations would affect the exact number +} + +LifetimeMethodsTestSuite.test("retain/release methods from base type") { + let a = DerivedRefCountedBox(321, 456) + expectEqual(a.value, 321) + expectEqual(a.secondValue, 456) + expectTrue(a.refCount > 0) + expectTrue(a.refCount < 10) // optimizations would affect the exact number + + a.secondValue = 789 + expectEqual(a.secondValue, 789) +} + +LifetimeMethodsTestSuite.test("retain in base type, release in derived type") { + let a = DerivedHasRelease(321) + expectEqual(a.value, 321) + expectTrue(a.refCount > 0) + expectTrue(a.refCount < 10) // optimizations would affect the exact number +} + +LifetimeMethodsTestSuite.test("retain in base type, release in derived templated type") { + let a = TemplatedDerivedHasReleaseInt(456) + expectEqual(a.value, 456) + expectTrue(a.refCount > 0) + expectTrue(a.refCount < 10) // optimizations would affect the exact number + + let b = TemplatedDerivedHasReleaseFloat(5.66) + expectEqual(b.value, 5.66) +} + +LifetimeMethodsTestSuite.test("CRTP") { + let a = CRTPDerived(789) + expectEqual(a.value, 789) + expectTrue(a.refCount > 0) + expectTrue(a.refCount < 10) // optimizations would affect the exact number +} + +LifetimeMethodsTestSuite.test("virtual retain/release") { + let a = VirtualRetainRelease(456) + expectEqual(a.value, 456) + expectTrue(a.refCount > 0) + expectTrue(a.refCount < 10) // optimizations would affect the exact number +} + +LifetimeMethodsTestSuite.test("overridden virtual retain/release") { + let a = DerivedVirtualRetainRelease(456) + expectEqual(a.value, 456) + expectTrue(a.calledDerived) + expectTrue(a.refCount > 0) + expectTrue(a.refCount < 10) // optimizations would affect the exact number +} + +LifetimeMethodsTestSuite.test("overridden pure virtual retain/release") { + let a = DerivedPureVirtualRetainRelease(789) + expectEqual(a.value, 789) + expectTrue(a.refCount > 0) + expectTrue(a.refCount < 10) // optimizations would affect the exact number +} + +runAllTests()