Skip to content

Commit f4cf914

Browse files
committed
[cxx-interop] Add COPYABLE_IF macro
Refactor `ClangTypeEscapability::evaluate` so that we can reuse part of it for `CxxValueSemantics::evaluate`
1 parent e64241c commit f4cf914

File tree

4 files changed

+185
-86
lines changed

4 files changed

+185
-86
lines changed

include/swift/ClangImporter/ClangImporterRequests.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -576,6 +576,7 @@ SourceLoc extractNearestSourceLoc(EscapabilityLookupDescriptor desc);
576576
// the object’s storage. This means reference types can be imported as
577577
// copyable to Swift, even when they are non-copyable in C++.
578578
enum class CxxValueSemanticsKind {
579+
Unknown,
579580
Copyable,
580581
MoveOnly,
581582
// A record that is either not copyable/movable or not destructible.

lib/ClangImporter/ClangImporter.cpp

Lines changed: 128 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -5303,16 +5303,66 @@ static const llvm::StringMap<std::vector<int>> STLConditionalParams{
53035303
{"unordered_multimap", {0, 1}},
53045304
};
53055305

5306+
template <typename Kind>
5307+
static std::optional<Kind> checkConditionalParams(
5308+
clang::RecordDecl *recordDecl, const std::vector<int> &STLParams,
5309+
std::set<StringRef> &conditionalParams,
5310+
std::function<std::optional<Kind>(clang::TemplateArgument &, StringRef)>
5311+
&checkArg) {
5312+
auto specDecl = cast<clang::ClassTemplateSpecializationDecl>(recordDecl);
5313+
SmallVector<std::pair<unsigned, StringRef>, 4> argumentsToCheck;
5314+
bool hasInjectedSTLAnnotation = !STLParams.empty();
5315+
while (specDecl) {
5316+
auto templateDecl = specDecl->getSpecializedTemplate();
5317+
if (hasInjectedSTLAnnotation) {
5318+
auto params = templateDecl->getTemplateParameters();
5319+
for (auto idx : STLParams)
5320+
argumentsToCheck.push_back(
5321+
std::make_pair(idx, params->getParam(idx)->getName()));
5322+
} else {
5323+
for (auto [idx, param] :
5324+
llvm::enumerate(*templateDecl->getTemplateParameters())) {
5325+
if (conditionalParams.erase(param->getName()))
5326+
argumentsToCheck.push_back(std::make_pair(idx, param->getName()));
5327+
}
5328+
}
5329+
auto &argList = specDecl->getTemplateArgs();
5330+
for (auto argToCheck : argumentsToCheck) {
5331+
auto arg = argList[argToCheck.first];
5332+
llvm::SmallVector<clang::TemplateArgument, 1> nonPackArgs;
5333+
if (arg.getKind() == clang::TemplateArgument::Pack) {
5334+
auto pack = arg.getPackAsArray();
5335+
nonPackArgs.assign(pack.begin(), pack.end());
5336+
} else
5337+
nonPackArgs.push_back(arg);
5338+
for (auto nonPackArg : nonPackArgs) {
5339+
auto result = checkArg(nonPackArg, argToCheck.second);
5340+
if (result.has_value())
5341+
return result.value();
5342+
}
5343+
}
5344+
if (hasInjectedSTLAnnotation)
5345+
break;
5346+
clang::DeclContext *dc = specDecl;
5347+
specDecl = nullptr;
5348+
while ((dc = dc->getParent())) {
5349+
specDecl = dyn_cast<clang::ClassTemplateSpecializationDecl>(dc);
5350+
if (specDecl)
5351+
break;
5352+
}
5353+
}
5354+
return std::nullopt;
5355+
}
5356+
53065357
static std::set<StringRef>
5307-
getConditionalEscapableAttrParams(const clang::RecordDecl *decl) {
5358+
getConditionalAttrParams(const clang::RecordDecl *decl, StringRef attrName) {
53085359
std::set<StringRef> result;
53095360
if (!decl->hasAttrs())
53105361
return result;
53115362
for (auto attr : decl->getAttrs()) {
5312-
if (auto swiftAttr = dyn_cast<clang::SwiftAttrAttr>(attr))
5313-
if (swiftAttr->getAttribute().starts_with("escapable_if:")) {
5314-
StringRef params = swiftAttr->getAttribute().drop_front(
5315-
StringRef("escapable_if:").size());
5363+
if (auto swiftAttr = dyn_cast<clang::SwiftAttrAttr>(attr)) {
5364+
StringRef params = swiftAttr->getAttribute();
5365+
if (params.consume_front(attrName)) {
53165366
auto commaPos = params.find(',');
53175367
StringRef nextParam = params.take_front(commaPos);
53185368
while (!nextParam.empty() && commaPos != StringRef::npos) {
@@ -5322,10 +5372,21 @@ getConditionalEscapableAttrParams(const clang::RecordDecl *decl) {
53225372
nextParam = params.take_front(commaPos);
53235373
}
53245374
}
5375+
}
53255376
}
53265377
return result;
53275378
}
53285379

5380+
static std::set<StringRef>
5381+
getConditionalEscapableAttrParams(const clang::RecordDecl *decl) {
5382+
return getConditionalAttrParams(decl, "escapable_if:");
5383+
}
5384+
5385+
static std::set<StringRef>
5386+
getConditionalCopyableAttrParams(const clang::RecordDecl *decl) {
5387+
return getConditionalAttrParams(decl, "copyable_if:");
5388+
}
5389+
53295390
CxxEscapability
53305391
ClangTypeEscapability::evaluate(Evaluator &evaluator,
53315392
EscapabilityLookupDescriptor desc) const {
@@ -5351,60 +5412,33 @@ ClangTypeEscapability::evaluate(Evaluator &evaluator,
53515412
recordDecl->isInStdNamespace()
53525413
? STLConditionalParams.find(recordDecl->getName())
53535414
: STLConditionalParams.end();
5354-
bool hasInjectedSTLAnnotation =
5355-
injectedStlAnnotation != STLConditionalParams.end();
5415+
auto STLParams = injectedStlAnnotation != STLConditionalParams.end()
5416+
? injectedStlAnnotation->second
5417+
: std::vector<int>();
53565418
auto conditionalParams = getConditionalEscapableAttrParams(recordDecl);
5357-
if (!conditionalParams.empty() || hasInjectedSTLAnnotation) {
5358-
auto specDecl = cast<clang::ClassTemplateSpecializationDecl>(recordDecl);
5359-
SmallVector<std::pair<unsigned, StringRef>, 4> argumentsToCheck;
5419+
5420+
if (!STLParams.empty() || !conditionalParams.empty()) {
53605421
HeaderLoc loc{recordDecl->getLocation()};
5361-
while (specDecl) {
5362-
auto templateDecl = specDecl->getSpecializedTemplate();
5363-
if (hasInjectedSTLAnnotation) {
5364-
auto params = templateDecl->getTemplateParameters();
5365-
for (auto idx : injectedStlAnnotation->second)
5366-
argumentsToCheck.push_back(
5367-
std::make_pair(idx, params->getParam(idx)->getName()));
5368-
} else {
5369-
for (auto [idx, param] :
5370-
llvm::enumerate(*templateDecl->getTemplateParameters())) {
5371-
if (conditionalParams.erase(param->getName()))
5372-
argumentsToCheck.push_back(std::make_pair(idx, param->getName()));
5373-
}
5422+
std::function checkArgEscapability =
5423+
[&](clang::TemplateArgument &arg,
5424+
StringRef argToCheck) -> std::optional<CxxEscapability> {
5425+
if (arg.getKind() != clang::TemplateArgument::Type && desc.impl) {
5426+
desc.impl->diagnose(loc, diag::type_template_parameter_expected,
5427+
argToCheck);
5428+
return CxxEscapability::Unknown;
53745429
}
5375-
auto &argList = specDecl->getTemplateArgs();
5376-
for (auto argToCheck : argumentsToCheck) {
5377-
auto arg = argList[argToCheck.first];
5378-
llvm::SmallVector<clang::TemplateArgument, 1> nonPackArgs;
5379-
if (arg.getKind() == clang::TemplateArgument::Pack) {
5380-
auto pack = arg.getPackAsArray();
5381-
nonPackArgs.assign(pack.begin(), pack.end());
5382-
} else
5383-
nonPackArgs.push_back(arg);
5384-
for (auto nonPackArg : nonPackArgs) {
5385-
if (nonPackArg.getKind() != clang::TemplateArgument::Type &&
5386-
desc.impl) {
5387-
desc.impl->diagnose(loc, diag::type_template_parameter_expected,
5388-
argToCheck.second);
5389-
return CxxEscapability::Unknown;
5390-
}
53915430

5392-
auto argEscapability = evaluateEscapability(
5393-
nonPackArg.getAsType()->getUnqualifiedDesugaredType());
5394-
if (argEscapability == CxxEscapability::NonEscapable)
5395-
return CxxEscapability::NonEscapable;
5396-
}
5397-
}
5398-
if (hasInjectedSTLAnnotation)
5399-
break;
5400-
clang::DeclContext *dc = specDecl;
5401-
specDecl = nullptr;
5402-
while ((dc = dc->getParent())) {
5403-
specDecl = dyn_cast<clang::ClassTemplateSpecializationDecl>(dc);
5404-
if (specDecl)
5405-
break;
5406-
}
5407-
}
5431+
auto argEscapability = evaluateEscapability(
5432+
arg.getAsType()->getUnqualifiedDesugaredType());
5433+
if (argEscapability == CxxEscapability::NonEscapable)
5434+
return CxxEscapability::NonEscapable;
5435+
return std::nullopt;
5436+
};
5437+
5438+
auto result = checkConditionalParams<CxxEscapability>(
5439+
recordDecl, STLParams, conditionalParams, checkArgEscapability);
5440+
if (result.has_value())
5441+
return result.value();
54085442

54095443
if (desc.impl)
54105444
for (auto name : conditionalParams)
@@ -8343,36 +8377,48 @@ CxxValueSemantics::evaluate(Evaluator &evaluator,
83438377
if (recordDecl->getIdentifier() &&
83448378
recordDecl->getName() == "_Optional_construct_base")
83458379
return CxxValueSemanticsKind::Copyable;
8380+
}
83468381

8347-
auto injectedStlAnnotation =
8348-
STLConditionalParams.find(recordDecl->getName());
8349-
8350-
if (injectedStlAnnotation != STLConditionalParams.end()) {
8351-
auto specDecl = cast<clang::ClassTemplateSpecializationDecl>(recordDecl);
8352-
auto &argList = specDecl->getTemplateArgs();
8353-
for (auto argToCheck : injectedStlAnnotation->second) {
8354-
auto arg = argList[argToCheck];
8355-
llvm::SmallVector<clang::TemplateArgument, 1> nonPackArgs;
8356-
if (arg.getKind() == clang::TemplateArgument::Pack) {
8357-
auto pack = arg.getPackAsArray();
8358-
nonPackArgs.assign(pack.begin(), pack.end());
8359-
} else
8360-
nonPackArgs.push_back(arg);
8361-
for (auto nonPackArg : nonPackArgs) {
8362-
8363-
auto argValueSemantics = evaluateOrDefault(
8364-
evaluator,
8365-
CxxValueSemantics(
8366-
{nonPackArg.getAsType()->getUnqualifiedDesugaredType(),
8367-
desc.importerImpl}),
8368-
{});
8369-
if (argValueSemantics != CxxValueSemanticsKind::Copyable)
8370-
return argValueSemantics;
8371-
}
8382+
auto injectedStlAnnotation =
8383+
recordDecl->isInStdNamespace()
8384+
? STLConditionalParams.find(recordDecl->getName())
8385+
: STLConditionalParams.end();
8386+
auto STLParams = injectedStlAnnotation != STLConditionalParams.end()
8387+
? injectedStlAnnotation->second
8388+
: std::vector<int>();
8389+
auto conditionalParams = getConditionalCopyableAttrParams(recordDecl);
8390+
8391+
if (!STLParams.empty() || !conditionalParams.empty()) {
8392+
HeaderLoc loc{recordDecl->getLocation()};
8393+
std::function checkArgValueSemantics =
8394+
[&](clang::TemplateArgument &arg,
8395+
StringRef argToCheck) -> std::optional<CxxValueSemanticsKind> {
8396+
if (arg.getKind() != clang::TemplateArgument::Type && importerImpl) {
8397+
importerImpl->diagnose(loc, diag::type_template_parameter_expected,
8398+
argToCheck);
8399+
return CxxValueSemanticsKind::Unknown;
83728400
}
83738401

8374-
return CxxValueSemanticsKind::Copyable;
8375-
}
8402+
auto argValueSemantics = evaluateOrDefault(
8403+
evaluator,
8404+
CxxValueSemantics(
8405+
{arg.getAsType()->getUnqualifiedDesugaredType(), importerImpl}),
8406+
{});
8407+
if (argValueSemantics != CxxValueSemanticsKind::Copyable)
8408+
return argValueSemantics;
8409+
return std::nullopt;
8410+
};
8411+
8412+
auto result = checkConditionalParams<CxxValueSemanticsKind>(
8413+
recordDecl, STLParams, conditionalParams, checkArgValueSemantics);
8414+
if (result.has_value())
8415+
return result.value();
8416+
8417+
if (importerImpl)
8418+
for (auto name : conditionalParams)
8419+
importerImpl->diagnose(loc, diag::unknown_template_parameter, name);
8420+
8421+
return CxxValueSemanticsKind::Copyable;
83768422
}
83778423

83788424
const auto cxxRecordDecl = dyn_cast<clang::CXXRecordDecl>(recordDecl);

lib/ClangImporter/SwiftBridging/swift/bridging

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,11 @@
190190
__attribute__((swift_attr("~Copyable"))) \
191191
__attribute__((swift_attr(_CXX_INTEROP_STRINGIFY(destroy:_destroy))))
192192

193+
/// Specifies that a C++ `class` or `struct` should be imported as a copyable
194+
/// Swift value if all of the specified template arguments are copyable.
195+
#define SWIFT_COPYABLE_IF(...) \
196+
__attribute__((swift_attr("copyable_if:" _CXX_INTEROP_CONCAT(__VA_ARGS__))))
197+
193198
/// Specifies that a specific class or struct should be imported
194199
/// as a non-escapable Swift value type.
195200
#define SWIFT_NONESCAPABLE \
@@ -283,6 +288,7 @@
283288
#define SWIFT_UNCHECKED_SENDABLE
284289
#define SWIFT_NONCOPYABLE
285290
#define SWIDT_NONCOPYABLE_WITH_DESTROY(_destroy)
291+
#define SWIFT_COPYABLE_IF(...)
286292
#define SWIFT_NONESCAPABLE
287293
#define SWIFT_ESCAPABLE
288294
#define SWIFT_ESCAPABLE_IF(...)

test/Interop/Cxx/class/noncopyable-typechecker.swift

Lines changed: 50 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
// RUN: %empty-directory(%t)
22
// RUN: split-file %s %t
3-
// RUN: %target-swift-frontend -cxx-interoperability-mode=default -typecheck -verify -I %t/Inputs %t/test.swift
4-
// RUN: %target-swift-frontend -cxx-interoperability-mode=default -Xcc -std=c++20 -verify-additional-prefix cpp20- -D CPP20 -typecheck -verify -I %t/Inputs %t/test.swift
3+
// RUN: %target-swift-frontend -cxx-interoperability-mode=default -typecheck -verify -I %swift_src_root/lib/ClangImporter/SwiftBridging -I %t/Inputs %t/test.swift
4+
// RUN: %target-swift-frontend -cxx-interoperability-mode=default -Xcc -std=c++20 -verify-additional-prefix cpp20- -D CPP20 -typecheck -verify -I %swift_src_root/lib/ClangImporter/SwiftBridging -I %t/Inputs %t/test.swift
55

66
//--- Inputs/module.modulemap
77
module Test {
@@ -10,6 +10,7 @@ module Test {
1010
}
1111

1212
//--- Inputs/noncopyable.h
13+
#include "swift/bridging"
1314
#include <string>
1415

1516
struct NonCopyable {
@@ -28,6 +29,29 @@ struct OwnsT {
2829

2930
using OwnsNonCopyable = OwnsT<NonCopyable>;
3031

32+
template <typename T>
33+
struct SWIFT_COPYABLE_IF(T) AnnotatedOwnsT {
34+
T element;
35+
AnnotatedOwnsT() {}
36+
AnnotatedOwnsT(const AnnotatedOwnsT &other) : element(other.element) {}
37+
AnnotatedOwnsT(AnnotatedOwnsT&& other) {}
38+
};
39+
40+
using AnnotatedOwnsNonCopyable = AnnotatedOwnsT<NonCopyable>;
41+
42+
template <typename F, typename S>
43+
struct SWIFT_COPYABLE_IF(F, S) MyPair {
44+
F first;
45+
S second;
46+
};
47+
48+
MyPair<int, NonCopyable> p1();
49+
MyPair<int, NonCopyable*> p2();
50+
MyPair<int, OwnsNonCopyable> p3();
51+
MyPair<int, AnnotatedOwnsNonCopyable> p4();
52+
MyPair<int, MyPair<int, NonCopyable>> p5();
53+
MyPair<NonCopyable, int> p6();
54+
3155
#if __cplusplus >= 202002L
3256
template <typename T>
3357
struct RequiresCopyableT {
@@ -38,6 +62,9 @@ struct RequiresCopyableT {
3862
};
3963

4064
using NonCopyableRequires = RequiresCopyableT<NonCopyable>;
65+
using CopyableIfRequires = RequiresCopyableT<MyPair<int, NonCopyable>>;
66+
67+
MyPair<int, NonCopyableRequires> p7();
4168

4269
#endif
4370

@@ -55,9 +82,28 @@ func userDefinedTypes() {
5582
takeCopyable(ownsT) // no error, OwnsNonCopyable imported as Copyable
5683
}
5784

85+
func useCopyableIf() {
86+
takeCopyable(p1()) // expected-error {{global function 'takeCopyable' requires that 'MyPair<CInt, NonCopyable>' conform to 'Copyable'}}
87+
takeCopyable(p2())
88+
89+
// p3() -> MyPair<int, OwnsNonCopyable> is imported as Copyable and will cause an error during IRGen.
90+
// During typecheck we don't produce an error because we're missing an annotation in OwnsT.
91+
takeCopyable(p3())
92+
// p4() -> (MyPair<int, AnnotatedOwnsNonCopyable>) is imported as NonCopyable because AnnotatedOwnsT is correctly annotated.
93+
takeCopyable(p4()) // expected-error {{global function 'takeCopyable' requires that 'MyPair<CInt, AnnotatedOwnsT<NonCopyable>>' conform to 'Copyable'}}
94+
95+
takeCopyable(p5()) // expected-error {{global function 'takeCopyable' requires that 'MyPair<CInt, MyPair<CInt, NonCopyable>>' conform to 'Copyable'}}
96+
takeCopyable(p6()) // expected-error {{global function 'takeCopyable' requires that 'MyPair<NonCopyable, CInt>' conform to 'Copyable'}}
97+
}
98+
5899
#if CPP20
59100
func useOfRequires() {
60-
let nCop = NonCopyableRequires()
61-
takeCopyable(nCop) // expected-cpp20-error {{global function 'takeCopyable' requires that 'NonCopyableRequires' (aka 'RequiresCopyableT<NonCopyable>') conform to 'Copyable'}}
101+
let a = NonCopyableRequires()
102+
takeCopyable(a) // expected-cpp20-error {{global function 'takeCopyable' requires that 'NonCopyableRequires' (aka 'RequiresCopyableT<NonCopyable>') conform to 'Copyable'}}
103+
104+
let b = CopyableIfRequires()
105+
takeCopyable(b) // expected-cpp20-error {{global function 'takeCopyable' requires that 'CopyableIfRequires' (aka 'RequiresCopyableT<MyPair<CInt, NonCopyable>>') conform to 'Copyable'}}
106+
107+
takeCopyable(p7()) // expected-cpp20-error {{global function 'takeCopyable' requires that 'MyPair<CInt, RequiresCopyableT<NonCopyable>>' conform to 'Copyable'}}
62108
}
63109
#endif

0 commit comments

Comments
 (0)