diff --git a/include/swift/AST/KnownProtocols.def b/include/swift/AST/KnownProtocols.def index da4f40596e202..ffea06fd02040 100644 --- a/include/swift/AST/KnownProtocols.def +++ b/include/swift/AST/KnownProtocols.def @@ -105,6 +105,7 @@ PROTOCOL(DistributedTargetInvocationDecoder) PROTOCOL(DistributedTargetInvocationResultHandler) // C++ Standard Library Overlay: +PROTOCOL(CxxSequence) PROTOCOL(UnsafeCxxInputIterator) PROTOCOL(AsyncSequence) diff --git a/lib/AST/ASTContext.cpp b/lib/AST/ASTContext.cpp index 910f7b81363d3..4b443e7a1c671 100644 --- a/lib/AST/ASTContext.cpp +++ b/lib/AST/ASTContext.cpp @@ -1054,6 +1054,7 @@ ProtocolDecl *ASTContext::getProtocol(KnownProtocolKind kind) const { case KnownProtocolKind::DistributedTargetInvocationResultHandler: M = getLoadedModule(Id_Distributed); break; + case KnownProtocolKind::CxxSequence: case KnownProtocolKind::UnsafeCxxInputIterator: M = getLoadedModule(Id_Cxx); break; diff --git a/lib/ClangImporter/ClangDerivedConformances.cpp b/lib/ClangImporter/ClangDerivedConformances.cpp index 185adebb0ce5a..d1fc7e0998cee 100644 --- a/lib/ClangImporter/ClangDerivedConformances.cpp +++ b/lib/ClangImporter/ClangDerivedConformances.cpp @@ -184,3 +184,80 @@ void swift::conformToCxxIteratorIfNeeded( impl.addSynthesizedProtocolAttrs(decl, {KnownProtocolKind::UnsafeCxxInputIterator}); } + +void swift::conformToCxxSequenceIfNeeded( + ClangImporter::Implementation &impl, NominalTypeDecl *decl, + const clang::CXXRecordDecl *clangDecl) { + PrettyStackTraceDecl trace("conforming to CxxSequence", decl); + + assert(decl); + assert(clangDecl); + ASTContext &ctx = decl->getASTContext(); + + ProtocolDecl *cxxIteratorProto = + ctx.getProtocol(KnownProtocolKind::UnsafeCxxInputIterator); + ProtocolDecl *cxxSequenceProto = + ctx.getProtocol(KnownProtocolKind::CxxSequence); + // If the Cxx module is missing, or does not include one of the necessary + // protocols, bail. + if (!cxxIteratorProto || !cxxSequenceProto) + return; + + // Check if present: `mutating func __beginUnsafe() -> RawIterator` + auto beginId = ctx.getIdentifier("__beginUnsafe"); + auto begins = lookupDirectWithoutExtensions(decl, beginId); + if (begins.size() != 1) + return; + auto begin = dyn_cast(begins.front()); + if (!begin) + return; + auto rawIteratorTy = begin->getResultInterfaceType(); + + // Check if present: `mutating func __endUnsafe() -> RawIterator` + auto endId = ctx.getIdentifier("__endUnsafe"); + auto ends = lookupDirectWithoutExtensions(decl, endId); + if (ends.size() != 1) + return; + auto end = dyn_cast(ends.front()); + if (!end) + return; + + // Check if `__beginUnsafe` and `__endUnsafe` have the same return type. + auto endTy = end->getResultInterfaceType(); + if (!endTy || endTy->getCanonicalType() != rawIteratorTy->getCanonicalType()) + return; + + // Check if RawIterator conforms to UnsafeCxxInputIterator. + auto rawIteratorConformanceRef = decl->getModuleContext()->lookupConformance( + rawIteratorTy, cxxIteratorProto); + if (!rawIteratorConformanceRef.isConcrete()) + return; + auto rawIteratorConformance = rawIteratorConformanceRef.getConcrete(); + auto pointeeDecl = + cxxIteratorProto->getAssociatedType(ctx.getIdentifier("Pointee")); + assert(pointeeDecl && + "UnsafeCxxInputIterator must have a Pointee associated type"); + auto pointeeTy = rawIteratorConformance->getTypeWitness(pointeeDecl); + assert(pointeeTy && "valid conformance must have a Pointee witness"); + + // Take the default definition of `Iterator` from CxxSequence protocol. This + // type is currently `CxxIterator`. + auto iteratorDecl = cxxSequenceProto->getAssociatedType(ctx.Id_Iterator); + auto iteratorTy = iteratorDecl->getDefaultDefinitionType(); + // Substitute generic `Self` parameter. + auto cxxSequenceSelfTy = cxxSequenceProto->getSelfInterfaceType(); + auto declSelfTy = decl->getDeclaredInterfaceType(); + iteratorTy = iteratorTy.subst( + [&](SubstitutableType *dependentType) { + if (dependentType->isEqual(cxxSequenceSelfTy)) + return declSelfTy; + return Type(dependentType); + }, + LookUpConformanceInModule(decl->getModuleContext())); + + impl.addSynthesizedTypealias(decl, ctx.Id_Element, pointeeTy); + impl.addSynthesizedTypealias(decl, ctx.Id_Iterator, iteratorTy); + impl.addSynthesizedTypealias(decl, ctx.getIdentifier("RawIterator"), + rawIteratorTy); + impl.addSynthesizedProtocolAttrs(decl, {KnownProtocolKind::CxxSequence}); +} diff --git a/lib/ClangImporter/ClangDerivedConformances.h b/lib/ClangImporter/ClangDerivedConformances.h index ae74eac07f3f1..f82719468cbd7 100644 --- a/lib/ClangImporter/ClangDerivedConformances.h +++ b/lib/ClangImporter/ClangDerivedConformances.h @@ -21,11 +21,17 @@ namespace swift { bool isIterator(const clang::CXXRecordDecl *clangDecl); /// If the decl is a C++ input iterator, synthesize a conformance to the -/// UnsafeCxxInputIterator protocol, which is defined in the std overlay. +/// UnsafeCxxInputIterator protocol, which is defined in the Cxx module. void conformToCxxIteratorIfNeeded(ClangImporter::Implementation &impl, NominalTypeDecl *decl, const clang::CXXRecordDecl *clangDecl); +/// If the decl is a C++ sequence, synthesize a conformance to the CxxSequence +/// protocol, which is defined in the Cxx module. +void conformToCxxSequenceIfNeeded(ClangImporter::Implementation &impl, + NominalTypeDecl *decl, + const clang::CXXRecordDecl *clangDecl); + } // namespace swift #endif // SWIFT_CLANG_DERIVED_CONFORMANCES_H diff --git a/lib/ClangImporter/ImportDecl.cpp b/lib/ClangImporter/ImportDecl.cpp index f8a02d2edc472..a33dfa4258f82 100644 --- a/lib/ClangImporter/ImportDecl.cpp +++ b/lib/ClangImporter/ImportDecl.cpp @@ -2612,6 +2612,7 @@ namespace { if (clangModule && requiresCPlusPlus(clangModule)) { if (auto structDecl = dyn_cast_or_null(result)) { conformToCxxIteratorIfNeeded(Impl, structDecl, decl); + conformToCxxSequenceIfNeeded(Impl, structDecl, decl); } } diff --git a/lib/IRGen/GenMeta.cpp b/lib/IRGen/GenMeta.cpp index fe95035610289..0d1e4b5ed5973 100644 --- a/lib/IRGen/GenMeta.cpp +++ b/lib/IRGen/GenMeta.cpp @@ -5808,6 +5808,7 @@ SpecialProtocol irgen::getSpecialProtocolID(ProtocolDecl *P) { case KnownProtocolKind::DistributedTargetInvocationEncoder: case KnownProtocolKind::DistributedTargetInvocationDecoder: case KnownProtocolKind::DistributedTargetInvocationResultHandler: + case KnownProtocolKind::CxxSequence: case KnownProtocolKind::UnsafeCxxInputIterator: case KnownProtocolKind::SerialExecutor: case KnownProtocolKind::Sendable: diff --git a/test/Interop/Cxx/stdlib/overlay/Inputs/custom-sequence.h b/test/Interop/Cxx/stdlib/overlay/Inputs/custom-sequence.h index a9a31b9366477..3b8b26ab0ef88 100644 --- a/test/Interop/Cxx/stdlib/overlay/Inputs/custom-sequence.h +++ b/test/Interop/Cxx/stdlib/overlay/Inputs/custom-sequence.h @@ -37,4 +37,53 @@ struct SimpleEmptySequence { const int *end() const { return nullptr; } }; +struct HasMutatingBeginEnd { + ConstIterator begin() { return ConstIterator(1); } + ConstIterator end() { return ConstIterator(5); } +}; + +// TODO: this should conform to CxxSequence. +struct __attribute__((swift_attr("import_reference"), + swift_attr("retain:immortal"), + swift_attr("release:immortal"))) ImmortalSequence { + ConstIterator begin() { return ConstIterator(1); } + ConstIterator end() { return ConstIterator(5); } +}; + +// MARK: Types that are not actually sequences + +struct HasNoBeginMethod { + ConstIterator end() const { return ConstIterator(1); } +}; + +struct HasNoEndMethod { + ConstIterator begin() const { return ConstIterator(1); } +}; + +struct HasBeginEndTypeMismatch { + ConstIterator begin() const { return ConstIterator(1); } + ConstIteratorOutOfLineEq end() const { return ConstIteratorOutOfLineEq(3); } +}; + +struct HasBeginEndReturnNonIterators { + struct NotIterator {}; + + NotIterator begin() const { return NotIterator(); } + NotIterator end() const { return NotIterator(); } +}; + +// TODO: this should not be conformed to CxxSequence, because +// `const ConstIterator &` is imported as `UnsafePointer`, and +// calling `successor()` is not actually going to call +// `ConstIterator::operator++()`. It will increment the address instead. +struct HasBeginEndReturnRef { +private: + ConstIterator b = ConstIterator(1); + ConstIterator e = ConstIterator(5); + +public: + const ConstIterator &begin() const { return b; } + const ConstIterator &end() const { return e; } +}; + #endif // TEST_INTEROP_CXX_STDLIB_INPUTS_CUSTOM_SEQUENCE_H \ No newline at end of file diff --git a/test/Interop/Cxx/stdlib/overlay/custom-sequence-module-interface.swift b/test/Interop/Cxx/stdlib/overlay/custom-sequence-module-interface.swift new file mode 100644 index 0000000000000..85f2267bdf07b --- /dev/null +++ b/test/Interop/Cxx/stdlib/overlay/custom-sequence-module-interface.swift @@ -0,0 +1,60 @@ +// RUN: %target-swift-ide-test -print-module -module-to-print=CustomSequence -source-filename=x -I %S/Inputs -enable-experimental-cxx-interop -module-cache-path %t | %FileCheck %s + +// CHECK: import Cxx + +// CHECK: struct SimpleSequence : CxxSequence { +// CHECK: typealias Element = ConstIterator.Pointee +// CHECK: typealias Iterator = CxxIterator +// CHECK: typealias RawIterator = ConstIterator +// CHECK: } + +// CHECK: struct SimpleSequenceWithOutOfLineEqualEqual : CxxSequence { +// CHECK: typealias Element = ConstIteratorOutOfLineEq.Pointee +// CHECK: typealias Iterator = CxxIterator +// CHECK: typealias RawIterator = ConstIteratorOutOfLineEq +// CHECK: } + +// CHECK: struct SimpleArrayWrapper : CxxSequence { +// CHECK: typealias Element = UnsafePointer.Pointee +// CHECK: typealias Iterator = CxxIterator +// CHECK: typealias RawIterator = UnsafePointer +// CHECK: } + +// CHECK: struct SimpleArrayWrapperNullableIterators : CxxSequence { +// CHECK: typealias Element = Optional>.Pointee +// CHECK: typealias Iterator = CxxIterator +// CHECK: typealias RawIterator = UnsafePointer? +// CHECK: } + +// CHECK: struct SimpleEmptySequence : CxxSequence { +// CHECK: typealias Element = Optional>.Pointee +// CHECK: typealias Iterator = CxxIterator +// CHECK: typealias RawIterator = UnsafePointer? +// CHECK: } + +// CHECK: struct HasMutatingBeginEnd : CxxSequence { +// CHECK: typealias Element = ConstIterator.Pointee +// CHECK: typealias Iterator = CxxIterator +// CHECK: typealias RawIterator = ConstIterator +// CHECK: } + +// CHECK: struct HasNoBeginMethod { +// CHECK-NOT: typealias Element +// CHECK-NOT: typealias Iterator +// CHECK-NOT: typealias RawIterator +// CHECK: } +// CHECK: struct HasNoEndMethod { +// CHECK-NOT: typealias Element +// CHECK-NOT: typealias Iterator +// CHECK-NOT: typealias RawIterator +// CHECK: } +// CHECK: struct HasBeginEndTypeMismatch { +// CHECK-NOT: typealias Element +// CHECK-NOT: typealias Iterator +// CHECK-NOT: typealias RawIterator +// CHECK: } +// CHECK: struct HasBeginEndReturnNonIterators { +// CHECK-NOT: typealias Element +// CHECK-NOT: typealias Iterator +// CHECK-NOT: typealias RawIterator +// CHECK: } diff --git a/test/Interop/Cxx/stdlib/overlay/custom-sequence-typechecker.swift b/test/Interop/Cxx/stdlib/overlay/custom-sequence-typechecker.swift index d7858274d7f30..cfd6ef48ebc47 100644 --- a/test/Interop/Cxx/stdlib/overlay/custom-sequence-typechecker.swift +++ b/test/Interop/Cxx/stdlib/overlay/custom-sequence-typechecker.swift @@ -3,12 +3,7 @@ import CustomSequence import Cxx -// === SimpleSequence === -// Conformance to UnsafeCxxInputIterator is synthesized. -extension SimpleSequence: CxxSequence {} - -func checkSimpleSequence() { - let seq = SimpleSequence() +func checkIntSequence(_ seq: S) where S: Sequence, S.Element == Int32 { let contains = seq.contains(where: { $0 == 3 }) print(contains) @@ -17,17 +12,26 @@ func checkSimpleSequence() { } } +// === SimpleSequence === +// Conformance to UnsafeCxxInputIterator is synthesized. +// Conformance to CxxSequence is synthesized. +checkIntSequence(SimpleSequence()) + // === SimpleSequenceWithOutOfLineEqualEqual === -extension SimpleSequenceWithOutOfLineEqualEqual : CxxSequence {} +// Conformance to CxxSequence is synthesized. +checkIntSequence(SimpleSequenceWithOutOfLineEqualEqual()) // === SimpleArrayWrapper === // No UnsafeCxxInputIterator conformance required, since the iterators are actually UnsafePointers here. -extension SimpleArrayWrapper: CxxSequence {} +// Conformance to CxxSequence is synthesized. +checkIntSequence(SimpleArrayWrapper()) // === SimpleArrayWrapperNullableIterators === // No UnsafeCxxInputIterator conformance required, since the iterators are actually optional UnsafePointers here. -extension SimpleArrayWrapperNullableIterators: CxxSequence {} +// Conformance to CxxSequence is synthesized. +checkIntSequence(SimpleArrayWrapperNullableIterators()) // === SimpleEmptySequence === // No UnsafeCxxInputIterator conformance required, since the iterators are actually optional UnsafePointers here. -extension SimpleEmptySequence: CxxSequence {} +// Conformance to CxxSequence is synthesized. +checkIntSequence(SimpleEmptySequence()) diff --git a/test/Interop/Cxx/stdlib/overlay/custom-sequence.swift b/test/Interop/Cxx/stdlib/overlay/custom-sequence.swift index 98cf933d2b695..66d5700a23fb4 100644 --- a/test/Interop/Cxx/stdlib/overlay/custom-sequence.swift +++ b/test/Interop/Cxx/stdlib/overlay/custom-sequence.swift @@ -9,11 +9,6 @@ import Cxx var CxxSequenceTestSuite = TestSuite("CxxSequence") -extension SimpleSequence: CxxSequence {} - -extension SimpleEmptySequence: CxxSequence {} - - CxxSequenceTestSuite.test("SimpleSequence as Swift.Sequence") { let seq = SimpleSequence() let contains = seq.contains(where: { $0 == 3 })