From f4cf9146e4bf0556de8a8c9572986630a721aa45 Mon Sep 17 00:00:00 2001 From: susmonteiro Date: Tue, 23 Sep 2025 14:23:41 +0100 Subject: [PATCH] [cxx-interop] Add COPYABLE_IF macro Refactor `ClangTypeEscapability::evaluate` so that we can reuse part of it for `CxxValueSemantics::evaluate` --- .../ClangImporter/ClangImporterRequests.h | 1 + lib/ClangImporter/ClangImporter.cpp | 210 +++++++++++------- .../SwiftBridging/swift/bridging | 6 + .../Cxx/class/noncopyable-typechecker.swift | 54 ++++- 4 files changed, 185 insertions(+), 86 deletions(-) diff --git a/include/swift/ClangImporter/ClangImporterRequests.h b/include/swift/ClangImporter/ClangImporterRequests.h index cebaf06ee8a63..fe83d0fae0405 100644 --- a/include/swift/ClangImporter/ClangImporterRequests.h +++ b/include/swift/ClangImporter/ClangImporterRequests.h @@ -576,6 +576,7 @@ SourceLoc extractNearestSourceLoc(EscapabilityLookupDescriptor desc); // the object’s storage. This means reference types can be imported as // copyable to Swift, even when they are non-copyable in C++. enum class CxxValueSemanticsKind { + Unknown, Copyable, MoveOnly, // A record that is either not copyable/movable or not destructible. diff --git a/lib/ClangImporter/ClangImporter.cpp b/lib/ClangImporter/ClangImporter.cpp index 1cfc0429d5aaa..cfec5e3776460 100644 --- a/lib/ClangImporter/ClangImporter.cpp +++ b/lib/ClangImporter/ClangImporter.cpp @@ -5303,16 +5303,66 @@ static const llvm::StringMap> STLConditionalParams{ {"unordered_multimap", {0, 1}}, }; +template +static std::optional checkConditionalParams( + clang::RecordDecl *recordDecl, const std::vector &STLParams, + std::set &conditionalParams, + std::function(clang::TemplateArgument &, StringRef)> + &checkArg) { + auto specDecl = cast(recordDecl); + SmallVector, 4> argumentsToCheck; + bool hasInjectedSTLAnnotation = !STLParams.empty(); + while (specDecl) { + auto templateDecl = specDecl->getSpecializedTemplate(); + if (hasInjectedSTLAnnotation) { + auto params = templateDecl->getTemplateParameters(); + for (auto idx : STLParams) + argumentsToCheck.push_back( + std::make_pair(idx, params->getParam(idx)->getName())); + } else { + for (auto [idx, param] : + llvm::enumerate(*templateDecl->getTemplateParameters())) { + if (conditionalParams.erase(param->getName())) + argumentsToCheck.push_back(std::make_pair(idx, param->getName())); + } + } + auto &argList = specDecl->getTemplateArgs(); + for (auto argToCheck : argumentsToCheck) { + auto arg = argList[argToCheck.first]; + llvm::SmallVector nonPackArgs; + if (arg.getKind() == clang::TemplateArgument::Pack) { + auto pack = arg.getPackAsArray(); + nonPackArgs.assign(pack.begin(), pack.end()); + } else + nonPackArgs.push_back(arg); + for (auto nonPackArg : nonPackArgs) { + auto result = checkArg(nonPackArg, argToCheck.second); + if (result.has_value()) + return result.value(); + } + } + if (hasInjectedSTLAnnotation) + break; + clang::DeclContext *dc = specDecl; + specDecl = nullptr; + while ((dc = dc->getParent())) { + specDecl = dyn_cast(dc); + if (specDecl) + break; + } + } + return std::nullopt; +} + static std::set -getConditionalEscapableAttrParams(const clang::RecordDecl *decl) { +getConditionalAttrParams(const clang::RecordDecl *decl, StringRef attrName) { std::set result; if (!decl->hasAttrs()) return result; for (auto attr : decl->getAttrs()) { - if (auto swiftAttr = dyn_cast(attr)) - if (swiftAttr->getAttribute().starts_with("escapable_if:")) { - StringRef params = swiftAttr->getAttribute().drop_front( - StringRef("escapable_if:").size()); + if (auto swiftAttr = dyn_cast(attr)) { + StringRef params = swiftAttr->getAttribute(); + if (params.consume_front(attrName)) { auto commaPos = params.find(','); StringRef nextParam = params.take_front(commaPos); while (!nextParam.empty() && commaPos != StringRef::npos) { @@ -5322,10 +5372,21 @@ getConditionalEscapableAttrParams(const clang::RecordDecl *decl) { nextParam = params.take_front(commaPos); } } + } } return result; } +static std::set +getConditionalEscapableAttrParams(const clang::RecordDecl *decl) { + return getConditionalAttrParams(decl, "escapable_if:"); +} + +static std::set +getConditionalCopyableAttrParams(const clang::RecordDecl *decl) { + return getConditionalAttrParams(decl, "copyable_if:"); +} + CxxEscapability ClangTypeEscapability::evaluate(Evaluator &evaluator, EscapabilityLookupDescriptor desc) const { @@ -5351,60 +5412,33 @@ ClangTypeEscapability::evaluate(Evaluator &evaluator, recordDecl->isInStdNamespace() ? STLConditionalParams.find(recordDecl->getName()) : STLConditionalParams.end(); - bool hasInjectedSTLAnnotation = - injectedStlAnnotation != STLConditionalParams.end(); + auto STLParams = injectedStlAnnotation != STLConditionalParams.end() + ? injectedStlAnnotation->second + : std::vector(); auto conditionalParams = getConditionalEscapableAttrParams(recordDecl); - if (!conditionalParams.empty() || hasInjectedSTLAnnotation) { - auto specDecl = cast(recordDecl); - SmallVector, 4> argumentsToCheck; + + if (!STLParams.empty() || !conditionalParams.empty()) { HeaderLoc loc{recordDecl->getLocation()}; - while (specDecl) { - auto templateDecl = specDecl->getSpecializedTemplate(); - if (hasInjectedSTLAnnotation) { - auto params = templateDecl->getTemplateParameters(); - for (auto idx : injectedStlAnnotation->second) - argumentsToCheck.push_back( - std::make_pair(idx, params->getParam(idx)->getName())); - } else { - for (auto [idx, param] : - llvm::enumerate(*templateDecl->getTemplateParameters())) { - if (conditionalParams.erase(param->getName())) - argumentsToCheck.push_back(std::make_pair(idx, param->getName())); - } + std::function checkArgEscapability = + [&](clang::TemplateArgument &arg, + StringRef argToCheck) -> std::optional { + if (arg.getKind() != clang::TemplateArgument::Type && desc.impl) { + desc.impl->diagnose(loc, diag::type_template_parameter_expected, + argToCheck); + return CxxEscapability::Unknown; } - auto &argList = specDecl->getTemplateArgs(); - for (auto argToCheck : argumentsToCheck) { - auto arg = argList[argToCheck.first]; - llvm::SmallVector nonPackArgs; - if (arg.getKind() == clang::TemplateArgument::Pack) { - auto pack = arg.getPackAsArray(); - nonPackArgs.assign(pack.begin(), pack.end()); - } else - nonPackArgs.push_back(arg); - for (auto nonPackArg : nonPackArgs) { - if (nonPackArg.getKind() != clang::TemplateArgument::Type && - desc.impl) { - desc.impl->diagnose(loc, diag::type_template_parameter_expected, - argToCheck.second); - return CxxEscapability::Unknown; - } - auto argEscapability = evaluateEscapability( - nonPackArg.getAsType()->getUnqualifiedDesugaredType()); - if (argEscapability == CxxEscapability::NonEscapable) - return CxxEscapability::NonEscapable; - } - } - if (hasInjectedSTLAnnotation) - break; - clang::DeclContext *dc = specDecl; - specDecl = nullptr; - while ((dc = dc->getParent())) { - specDecl = dyn_cast(dc); - if (specDecl) - break; - } - } + auto argEscapability = evaluateEscapability( + arg.getAsType()->getUnqualifiedDesugaredType()); + if (argEscapability == CxxEscapability::NonEscapable) + return CxxEscapability::NonEscapable; + return std::nullopt; + }; + + auto result = checkConditionalParams( + recordDecl, STLParams, conditionalParams, checkArgEscapability); + if (result.has_value()) + return result.value(); if (desc.impl) for (auto name : conditionalParams) @@ -8343,36 +8377,48 @@ CxxValueSemantics::evaluate(Evaluator &evaluator, if (recordDecl->getIdentifier() && recordDecl->getName() == "_Optional_construct_base") return CxxValueSemanticsKind::Copyable; + } - auto injectedStlAnnotation = - STLConditionalParams.find(recordDecl->getName()); - - if (injectedStlAnnotation != STLConditionalParams.end()) { - auto specDecl = cast(recordDecl); - auto &argList = specDecl->getTemplateArgs(); - for (auto argToCheck : injectedStlAnnotation->second) { - auto arg = argList[argToCheck]; - llvm::SmallVector nonPackArgs; - if (arg.getKind() == clang::TemplateArgument::Pack) { - auto pack = arg.getPackAsArray(); - nonPackArgs.assign(pack.begin(), pack.end()); - } else - nonPackArgs.push_back(arg); - for (auto nonPackArg : nonPackArgs) { - - auto argValueSemantics = evaluateOrDefault( - evaluator, - CxxValueSemantics( - {nonPackArg.getAsType()->getUnqualifiedDesugaredType(), - desc.importerImpl}), - {}); - if (argValueSemantics != CxxValueSemanticsKind::Copyable) - return argValueSemantics; - } + auto injectedStlAnnotation = + recordDecl->isInStdNamespace() + ? STLConditionalParams.find(recordDecl->getName()) + : STLConditionalParams.end(); + auto STLParams = injectedStlAnnotation != STLConditionalParams.end() + ? injectedStlAnnotation->second + : std::vector(); + auto conditionalParams = getConditionalCopyableAttrParams(recordDecl); + + if (!STLParams.empty() || !conditionalParams.empty()) { + HeaderLoc loc{recordDecl->getLocation()}; + std::function checkArgValueSemantics = + [&](clang::TemplateArgument &arg, + StringRef argToCheck) -> std::optional { + if (arg.getKind() != clang::TemplateArgument::Type && importerImpl) { + importerImpl->diagnose(loc, diag::type_template_parameter_expected, + argToCheck); + return CxxValueSemanticsKind::Unknown; } - return CxxValueSemanticsKind::Copyable; - } + auto argValueSemantics = evaluateOrDefault( + evaluator, + CxxValueSemantics( + {arg.getAsType()->getUnqualifiedDesugaredType(), importerImpl}), + {}); + if (argValueSemantics != CxxValueSemanticsKind::Copyable) + return argValueSemantics; + return std::nullopt; + }; + + auto result = checkConditionalParams( + recordDecl, STLParams, conditionalParams, checkArgValueSemantics); + if (result.has_value()) + return result.value(); + + if (importerImpl) + for (auto name : conditionalParams) + importerImpl->diagnose(loc, diag::unknown_template_parameter, name); + + return CxxValueSemanticsKind::Copyable; } const auto cxxRecordDecl = dyn_cast(recordDecl); diff --git a/lib/ClangImporter/SwiftBridging/swift/bridging b/lib/ClangImporter/SwiftBridging/swift/bridging index 4c8e97d4a8e80..83ef4b1ed4523 100644 --- a/lib/ClangImporter/SwiftBridging/swift/bridging +++ b/lib/ClangImporter/SwiftBridging/swift/bridging @@ -190,6 +190,11 @@ __attribute__((swift_attr("~Copyable"))) \ __attribute__((swift_attr(_CXX_INTEROP_STRINGIFY(destroy:_destroy)))) +/// Specifies that a C++ `class` or `struct` should be imported as a copyable +/// Swift value if all of the specified template arguments are copyable. +#define SWIFT_COPYABLE_IF(...) \ + __attribute__((swift_attr("copyable_if:" _CXX_INTEROP_CONCAT(__VA_ARGS__)))) + /// Specifies that a specific class or struct should be imported /// as a non-escapable Swift value type. #define SWIFT_NONESCAPABLE \ @@ -283,6 +288,7 @@ #define SWIFT_UNCHECKED_SENDABLE #define SWIFT_NONCOPYABLE #define SWIDT_NONCOPYABLE_WITH_DESTROY(_destroy) +#define SWIFT_COPYABLE_IF(...) #define SWIFT_NONESCAPABLE #define SWIFT_ESCAPABLE #define SWIFT_ESCAPABLE_IF(...) diff --git a/test/Interop/Cxx/class/noncopyable-typechecker.swift b/test/Interop/Cxx/class/noncopyable-typechecker.swift index 81e6707204286..aa6e3ab81685b 100644 --- a/test/Interop/Cxx/class/noncopyable-typechecker.swift +++ b/test/Interop/Cxx/class/noncopyable-typechecker.swift @@ -1,7 +1,7 @@ // RUN: %empty-directory(%t) // RUN: split-file %s %t -// RUN: %target-swift-frontend -cxx-interoperability-mode=default -typecheck -verify -I %t/Inputs %t/test.swift -// RUN: %target-swift-frontend -cxx-interoperability-mode=default -Xcc -std=c++20 -verify-additional-prefix cpp20- -D CPP20 -typecheck -verify -I %t/Inputs %t/test.swift +// RUN: %target-swift-frontend -cxx-interoperability-mode=default -typecheck -verify -I %swift_src_root/lib/ClangImporter/SwiftBridging -I %t/Inputs %t/test.swift +// RUN: %target-swift-frontend -cxx-interoperability-mode=default -Xcc -std=c++20 -verify-additional-prefix cpp20- -D CPP20 -typecheck -verify -I %swift_src_root/lib/ClangImporter/SwiftBridging -I %t/Inputs %t/test.swift //--- Inputs/module.modulemap module Test { @@ -10,6 +10,7 @@ module Test { } //--- Inputs/noncopyable.h +#include "swift/bridging" #include struct NonCopyable { @@ -28,6 +29,29 @@ struct OwnsT { using OwnsNonCopyable = OwnsT; +template +struct SWIFT_COPYABLE_IF(T) AnnotatedOwnsT { + T element; + AnnotatedOwnsT() {} + AnnotatedOwnsT(const AnnotatedOwnsT &other) : element(other.element) {} + AnnotatedOwnsT(AnnotatedOwnsT&& other) {} +}; + +using AnnotatedOwnsNonCopyable = AnnotatedOwnsT; + +template +struct SWIFT_COPYABLE_IF(F, S) MyPair { + F first; + S second; +}; + +MyPair p1(); +MyPair p2(); +MyPair p3(); +MyPair p4(); +MyPair> p5(); +MyPair p6(); + #if __cplusplus >= 202002L template struct RequiresCopyableT { @@ -38,6 +62,9 @@ struct RequiresCopyableT { }; using NonCopyableRequires = RequiresCopyableT; +using CopyableIfRequires = RequiresCopyableT>; + +MyPair p7(); #endif @@ -55,9 +82,28 @@ func userDefinedTypes() { takeCopyable(ownsT) // no error, OwnsNonCopyable imported as Copyable } +func useCopyableIf() { + takeCopyable(p1()) // expected-error {{global function 'takeCopyable' requires that 'MyPair' conform to 'Copyable'}} + takeCopyable(p2()) + + // p3() -> MyPair is imported as Copyable and will cause an error during IRGen. + // During typecheck we don't produce an error because we're missing an annotation in OwnsT. + takeCopyable(p3()) + // p4() -> (MyPair) is imported as NonCopyable because AnnotatedOwnsT is correctly annotated. + takeCopyable(p4()) // expected-error {{global function 'takeCopyable' requires that 'MyPair>' conform to 'Copyable'}} + + takeCopyable(p5()) // expected-error {{global function 'takeCopyable' requires that 'MyPair>' conform to 'Copyable'}} + takeCopyable(p6()) // expected-error {{global function 'takeCopyable' requires that 'MyPair' conform to 'Copyable'}} +} + #if CPP20 func useOfRequires() { - let nCop = NonCopyableRequires() - takeCopyable(nCop) // expected-cpp20-error {{global function 'takeCopyable' requires that 'NonCopyableRequires' (aka 'RequiresCopyableT') conform to 'Copyable'}} + let a = NonCopyableRequires() + takeCopyable(a) // expected-cpp20-error {{global function 'takeCopyable' requires that 'NonCopyableRequires' (aka 'RequiresCopyableT') conform to 'Copyable'}} + + let b = CopyableIfRequires() + takeCopyable(b) // expected-cpp20-error {{global function 'takeCopyable' requires that 'CopyableIfRequires' (aka 'RequiresCopyableT>') conform to 'Copyable'}} + + takeCopyable(p7()) // expected-cpp20-error {{global function 'takeCopyable' requires that 'MyPair>' conform to 'Copyable'}} } #endif