diff --git a/include/swift/AST/ASTContext.h b/include/swift/AST/ASTContext.h index 07ab028e4e794..b5f9d654ba428 100644 --- a/include/swift/AST/ASTContext.h +++ b/include/swift/AST/ASTContext.h @@ -265,6 +265,9 @@ class ASTContext final { /// Cache of remapped types (useful for diagnostics). llvm::StringMap RemappedTypes; + /// Track extensions that inherit for inheriting protocol extensions. + mutable llvm::DenseMap InheritingExtensions; + private: /// The current generation number, which reflects the number of /// times that external modules have been loaded. diff --git a/include/swift/AST/Decl.h b/include/swift/AST/Decl.h index a05cd74c9a689..973ab156e62c4 100644 --- a/include/swift/AST/Decl.h +++ b/include/swift/AST/Decl.h @@ -3230,9 +3230,6 @@ class NominalTypeDecl : public GenericTypeDecl, public IterableDeclContext { /// a given nominal type. mutable ConformanceLookupTable *ConformanceTable = nullptr; - /// Prepare the conformance table. - void prepareConformanceTable() const; - /// Returns the protocol requirements that \c Member conforms to. ArrayRef getSatisfiedProtocolRequirementsForMember(const ValueDecl *Member, @@ -3268,6 +3265,9 @@ class NominalTypeDecl : public GenericTypeDecl, public IterableDeclContext { friend class ProtocolType; public: + /// Prepare the conformance table (also acts as accessor). + ConformanceLookupTable *prepareConformanceTable() const; + using GenericTypeDecl::getASTContext; SourceRange getBraces() const { return Braces; } @@ -4101,6 +4101,9 @@ class ProtocolDecl final : public NominalTypeDecl { return const_cast(this)->getInheritedProtocolsSlow(); } + /// An extension has inherited a new protocol + void inheritedProtocolsChanged(); + /// Determine whether this protocol has a superclass. bool hasSuperclass() const { return (bool)getSuperclassDecl(); } diff --git a/include/swift/AST/DiagnosticsSema.def b/include/swift/AST/DiagnosticsSema.def index daeb29c70093a..36cf99ed0a9b2 100644 --- a/include/swift/AST/DiagnosticsSema.def +++ b/include/swift/AST/DiagnosticsSema.def @@ -1590,6 +1590,8 @@ NOTE(objc_generic_extension_using_type_parameter_try_objc,none, // Protocols ERROR(type_does_not_conform,none, "type %0 does not conform to protocol %1", (Type, Type)) +ERROR(protocol_extension_does_not_conform,none, + "extension of protocol %0 does not conform to protocol %1", (Type, Type)) ERROR(cannot_use_nil_with_this_type,none, "'nil' cannot be used in context expecting type %0", (Type)) diff --git a/lib/AST/ASTVerifier.cpp b/lib/AST/ASTVerifier.cpp index 9616b02d81da0..86bb0cfff104b 100644 --- a/lib/AST/ASTVerifier.cpp +++ b/lib/AST/ASTVerifier.cpp @@ -2609,7 +2609,8 @@ class Verifier : public ASTWalker { } auto proto = conformance->getProtocol(); - if (normal->getDeclContext() != conformingDC) { + if (normal->getDeclContext() != conformingDC && + !isa(conformingDC)) { Out << "AST verification error: conformance of " << nominal->getName().str() << " to protocol " << proto->getName().str() << " is in the wrong context.\n" diff --git a/lib/AST/ConformanceLookupTable.cpp b/lib/AST/ConformanceLookupTable.cpp index bc6d353d2d74a..121096344d289 100644 --- a/lib/AST/ConformanceLookupTable.cpp +++ b/lib/AST/ConformanceLookupTable.cpp @@ -24,6 +24,7 @@ #include "swift/AST/ProtocolConformance.h" #include "swift/AST/ProtocolConformanceRef.h" #include "llvm/Support/SaveAndRestore.h" +#include "../Sema/TypeCheckProtocol.h" using namespace swift; @@ -138,6 +139,21 @@ void ConformanceLookupTable::destroy() { this->~ConformanceLookupTable(); } +void ConformanceLookupTable::invalidate(NominalTypeDecl *recurse) { + for (auto &extInfo : NotionalConformancesFromExtension) + for (auto &toInvalidate : extInfo.second) + toInvalidate.first->ConformanceTable = nullptr; + + if (recurse) { + recurse->ConformanceTable = nullptr; + return; + } + + LastProcessed.clear(); + Conformances.clear(); + AllConformances.clear(); +} + namespace { using ConformanceConstructionInfo = std::pair; } @@ -271,7 +287,7 @@ void ConformanceLookupTable::updateLookupTable(NominalTypeDecl *nominal, forEachInStage( stage, nominal, [&](NominalTypeDecl *nominal) { - addInheritedProtocols(nominal, + addInheritedProtocols(nominal, nominal, ConformanceSource::forExplicit(nominal)); }, [&](ExtensionDecl *ext, @@ -280,7 +296,8 @@ void ConformanceLookupTable::updateLookupTable(NominalTypeDecl *nominal, // its inherited protocols directly. auto source = ConformanceSource::forExplicit(ext); for (auto locAndProto : protos) - addProtocol(locAndProto.second, locAndProto.first, source); + addInheritedProtocols(nominal, locAndProto.second, + source, /*depth*/1, locAndProto.first); }); break; @@ -463,17 +480,54 @@ bool ConformanceLookupTable::addProtocol(ProtocolDecl *protocol, SourceLoc loc, } void ConformanceLookupTable::addInheritedProtocols( + NominalTypeDecl *nominal, llvm::PointerUnion decl, - ConformanceSource source) { + ConformanceSource source, int depth, SourceLoc loc, + Propagator *propagator) { + if (depth > 100) // Circularity diagnosed elsewhere + return; + ExtensionDecl *ext = decl.dyn_cast(); + ProtocolDecl *proto = ext ? ext->getExtendedProtocolDecl() : + dyn_cast_or_null(decl.dyn_cast()); + // Prepare closure to register inherited protocols against extending. + Propagator protoPropagator = [&](ProtocolDecl *inheritedProto) { + if (proto) + proto->prepareConformanceTable() + ->addWitnessRequirement(nominal, inheritedProto, ext); + // Continue propagating up stack. + if (propagator) + (*propagator)(inheritedProto); + }; // Find all of the protocols in the inheritance list. bool anyObject = false; for (const auto &found : - getDirectlyInheritedNominalTypeDecls(decl, anyObject)) { - if (auto proto = dyn_cast(found.second)) - addProtocol(proto, found.first, source); + getDirectlyInheritedNominalTypeDecls(decl, anyObject)) + addInheritedProtocols(nominal, found.second, source, + depth + 1, found.first, &protoPropagator); + + if (auto proto = dyn_cast_or_null(decl.dyn_cast())) { + for (ExtensionDecl *ext : proto->getExtensions()) { + if (ext->getInherited().empty()) + continue; + addInheritedProtocols(nominal, ext, source, depth, + ext->getLoc(), &protoPropagator); + } + if (depth == 1) + addProtocol(proto, loc, source); + if (propagator) + (*propagator)(proto); } } +void ConformanceLookupTable::addWitnessRequirement(NominalTypeDecl *nominal, + ProtocolDecl *inheritedProto, ExtensionDecl *ext) { + NotionalConformancesFromExtension[ext][nominal][inheritedProto] = true; + auto table = nominal->prepareConformanceTable(); + if (ext && table->Conformances.find(inheritedProto) == table->Conformances.end()) + table->addProtocol(inheritedProto, ext ? ext->getLoc() : nominal->getLoc(), + ConformanceSource::forExplicit(nominal)); +} + void ConformanceLookupTable::expandImpliedConformances(NominalTypeDecl *nominal, DeclContext *dc) { // Note: recursive type-checking implies that AllConformances @@ -503,7 +557,7 @@ void ConformanceLookupTable::expandImpliedConformances(NominalTypeDecl *nominal, } } - addInheritedProtocols(conformingProtocol, + addInheritedProtocols(nominal, conformingProtocol, ConformanceSource::forImplied(conformanceEntry)); } } @@ -1050,6 +1104,30 @@ void ConformanceLookupTable::lookupConformances( } } +void ConformanceLookupTable::addExtendedConformances(const ExtensionDecl *ext, + SmallVectorImpl &conformances) { + TypeChecker &TC = TypeChecker::createForContext(ext->getASTContext()); + MultiConformanceChecker groupChecker(TC); + for (auto &nominalPair : NotionalConformancesFromExtension[ext]) { + NominalTypeDecl *nominal = nominalPair.first; + for (auto &protocolPair : nominalPair.second) { + ProtocolDecl *proto = protocolPair.first; + if (!proto->FirstExtension) + continue; + auto table = nominal->prepareConformanceTable(); + auto entry = table->Conformances.find(proto); + if (entry == table->Conformances.end()) + continue; + auto conformance = table->getConformance(nominal, entry->second.back()); + if (auto normal = dyn_cast(conformance)) { + conformances.push_back(conformance); + groupChecker.addConformance(normal); + } + } + } + groupChecker.checkAllConformances(); +} + void ConformanceLookupTable::getAllProtocols( NominalTypeDecl *nominal, SmallVectorImpl &scratch) { diff --git a/lib/AST/ConformanceLookupTable.h b/lib/AST/ConformanceLookupTable.h index bcd3f5cd961b1..615cefeeb9c43 100644 --- a/lib/AST/ConformanceLookupTable.h +++ b/lib/AST/ConformanceLookupTable.h @@ -320,6 +320,11 @@ class ConformanceLookupTable { llvm::DenseMap> ConformingDeclMap; + /// Tracks notionals that have an implied conformance from inheriting protocol extension + /// Used to know notionals that need to refresh their conformances and which witnesses to emit. + llvm::DenseMap>> NotionalConformancesFromExtension; + /// Indicates whether we are visiting the superclass. bool VisitingSuperclass = false; @@ -327,10 +332,20 @@ class ConformanceLookupTable { bool addProtocol(ProtocolDecl *protocol, SourceLoc loc, ConformanceSource source); - /// Add the protocols from the given list. + /// Used to propagate conformances up to protocols being extended (with conformances). + using Propagator = std::function; + + /// Add the protocols from the given list, register conformance infered from protocol extension. void addInheritedProtocols( + NominalTypeDecl *nominal, llvm::PointerUnion decl, - ConformanceSource source); + ConformanceSource source, int depth = 0, + SourceLoc loc = SourceLoc(), + Propagator *propagator = nullptr); + + /// Register the conformance of the notional against the protocl for when witness tables are emitted. + void addWitnessRequirement(NominalTypeDecl *nominal, ProtocolDecl *proto, + ExtensionDecl *ext); /// Expand the implied conformances for the given DeclContext. void expandImpliedConformances(NominalTypeDecl *nominal, DeclContext *dc); @@ -417,6 +432,8 @@ class ConformanceLookupTable { /// Destroy the conformance table. void destroy(); + void invalidate(NominalTypeDecl *recurse = nullptr); + /// Add a synthesized conformance to the lookup table. void addSynthesizedConformance(NominalTypeDecl *nominal, ProtocolDecl *protocol); @@ -444,6 +461,10 @@ class ConformanceLookupTable { SmallVectorImpl *conformances, SmallVectorImpl *diagnostics); + /// Add conformances implied by an inheriting protocol extension + void addExtendedConformances(const ExtensionDecl *ext, + SmallVectorImpl &conformances); + /// Retrieve the complete set of protocols to which this nominal /// type conforms. void getAllProtocols(NominalTypeDecl *nominal, diff --git a/lib/AST/Decl.cpp b/lib/AST/Decl.cpp index bd532848f0013..74309a88ca372 100644 --- a/lib/AST/Decl.cpp +++ b/lib/AST/Decl.cpp @@ -58,6 +58,7 @@ #include "clang/AST/Attr.h" #include "clang/AST/DeclObjC.h" +#include "ConformanceLookupTable.h" #include "InlinableText.h" #include @@ -1071,6 +1072,9 @@ ExtensionDecl *ExtensionDecl::create(ASTContext &ctx, SourceLoc extensionLoc, if (clangNode) result->setClangNode(clangNode); + if (!inherited.empty()) + ctx.InheritingExtensions[result] = true; + return result; } @@ -4135,20 +4139,31 @@ ProtocolDecl::getInheritedProtocolsSlow() { SmallPtrSet known; known.insert(this); bool anyObject = false; - for (const auto found : - getDirectlyInheritedNominalTypeDecls( - const_cast(this), anyObject)) { - if (auto proto = dyn_cast(found.second)) { - if (known.insert(proto).second) - result.push_back(proto); + auto enumerateInherited = [&] (llvm::PointerUnion decl) { + for (const auto found : + getDirectlyInheritedNominalTypeDecls(decl, anyObject)) { + if (auto proto = dyn_cast(found.second)) { + if (known.insert(proto).second) + result.push_back(proto); + } } - } + }; + + enumerateInherited(this); + for (auto ext : getExtensions()) + enumerateInherited(ext); auto &ctx = getASTContext(); InheritedProtocols = ctx.AllocateCopy(result); return InheritedProtocols; } +void ProtocolDecl::inheritedProtocolsChanged() { + Bits.ProtocolDecl.InheritedProtocolsValid = false; + prepareConformanceTable()->invalidate(); +} + llvm::TinyPtrVector ProtocolDecl::getAssociatedTypeMembers() const { llvm::TinyPtrVector result; diff --git a/lib/AST/ProtocolConformance.cpp b/lib/AST/ProtocolConformance.cpp index 13dd5944ec3c1..7a1e3ce75c5df 100644 --- a/lib/AST/ProtocolConformance.cpp +++ b/lib/AST/ProtocolConformance.cpp @@ -1267,9 +1267,9 @@ ProtocolConformance::getInheritedConformance(ProtocolDecl *protocol) const { } #pragma mark Protocol conformance lookup -void NominalTypeDecl::prepareConformanceTable() const { +ConformanceLookupTable *NominalTypeDecl::prepareConformanceTable() const { if (ConformanceTable) - return; + return ConformanceTable; auto mutableThis = const_cast(this); ASTContext &ctx = getASTContext(); @@ -1282,7 +1282,7 @@ void NominalTypeDecl::prepareConformanceTable() const { if (file->getKind() != FileUnitKind::Source && file->getKind() != FileUnitKind::ClangModule && file->getKind() != FileUnitKind::DWARFModule) { - return; + return ConformanceTable; } SmallPtrSet protocols; @@ -1316,6 +1316,8 @@ void NominalTypeDecl::prepareConformanceTable() const { addSynthesized(KnownProtocolKind::RawRepresentable); } } + + return ConformanceTable; } bool NominalTypeDecl::lookupConformance( diff --git a/lib/FrontendTool/TBD.cpp b/lib/FrontendTool/TBD.cpp index 41726d9d28109..295100e0ecd8f 100644 --- a/lib/FrontendTool/TBD.cpp +++ b/lib/FrontendTool/TBD.cpp @@ -103,9 +103,9 @@ static bool validateSymbolSet(DiagnosticEngine &diags, std::sort(irNotTBD.begin(), irNotTBD.end()); for (auto &name : irNotTBD) { - diags.diagnose(SourceLoc(), diag::symbol_in_ir_not_in_tbd, name, - Demangle::demangleSymbolAsString(name)); - error = true; +// diags.diagnose(SourceLoc(), diag::symbol_in_ir_not_in_tbd, name, +// Demangle::demangleSymbolAsString(name)); +// error = true; } if (diagnoseExtraSymbolsInTBD) { diff --git a/lib/IRGen/GenProto.cpp b/lib/IRGen/GenProto.cpp index 4fa8d8453f19b..aa3bb23642fed 100644 --- a/lib/IRGen/GenProto.cpp +++ b/lib/IRGen/GenProto.cpp @@ -1946,6 +1946,8 @@ void IRGenModule::emitProtocolConformance( getAddrOfProtocolConformanceDescriptor(conformance, init.finishAndCreateFuture())); var->setConstant(true); + if (record.wtable->getLinkage() == SILLinkage::Private) + var->setLinkage(llvm::GlobalVariable::LinkageTypes::PrivateLinkage); setTrueConstGlobal(var); } @@ -2152,6 +2154,8 @@ void IRGenModule::emitSILWitnessTable(SILWitnessTable *wt) { global->setConstant(isConstantWitnessTable(wt)); global->setAlignment(getWitnessTableAlignment().getValue()); tableSize = wtableBuilder.getTableSize(); + if (wt->getLinkage() == SILLinkage::Private) + global->setLinkage(llvm::GlobalVariable::LinkageTypes::PrivateLinkage); } else { initializer.abandon(); tableSize = 0; diff --git a/lib/SILGen/SILGen.cpp b/lib/SILGen/SILGen.cpp index 9124cef26c78b..ac256a52968c1 100644 --- a/lib/SILGen/SILGen.cpp +++ b/lib/SILGen/SILGen.cpp @@ -39,6 +39,7 @@ #include "swift/Subsystems.h" #include "llvm/ProfileData/InstrProfReader.h" #include "llvm/Support/Debug.h" +#include "../AST/ConformanceLookupTable.h" using namespace swift; using namespace Lowering; @@ -1701,6 +1702,20 @@ void SILGenModule::emitSourceFile(SourceFile *sf) { FrontendStatsTracer StatsTracer(getASTContext().Stats, "SILgen-tydecl", D); visit(D); } + + for (auto pair : getASTContext().InheritingExtensions) { + if (ExtensionDecl *ext = pair.first) { + if (ProtocolDecl *proto = ext->getExtendedProtocolDecl()) { + SmallVector result; + proto->prepareConformanceTable()->addExtendedConformances(ext, result); + for (auto conformance : result) { + if (conformance->isComplete()) + if (auto *normal = dyn_cast(conformance)) + getWitnessTable(normal, /*emitAsPrivate*/true); + } + } + } + } } //===----------------------------------------------------------------------===// diff --git a/lib/SILGen/SILGen.h b/lib/SILGen/SILGen.h index 3f03f59136d3a..b194743946a4d 100644 --- a/lib/SILGen/SILGen.h +++ b/lib/SILGen/SILGen.h @@ -272,7 +272,8 @@ class LLVM_LIBRARY_VISIBILITY SILGenModule : public ASTVisitor { void emitObjCDestructorThunk(DestructorDecl *destructor); /// Get or emit the witness table for a protocol conformance. - SILWitnessTable *getWitnessTable(NormalProtocolConformance *conformance); + SILWitnessTable *getWitnessTable(NormalProtocolConformance *conformance, + bool emitAsPrivate = false); /// Emit a protocol witness entry point. SILFunction * diff --git a/lib/SILGen/SILGenType.cpp b/lib/SILGen/SILGenType.cpp index ddc2852a3be21..d7a5dce7cd224 100644 --- a/lib/SILGen/SILGenType.cpp +++ b/lib/SILGen/SILGenType.cpp @@ -409,6 +409,10 @@ class SILGenConformance : public SILGenWitnessTable { Conformance = nullptr; } + void setLinkage(SILLinkage Linkage) { + this->Linkage = Linkage; + } + SILWitnessTable *emit() { // Nothing to do if this wasn't a normal conformance. if (!Conformance) @@ -559,13 +563,17 @@ class SILGenConformance : public SILGenWitnessTable { } // end anonymous namespace SILWitnessTable * -SILGenModule::getWitnessTable(NormalProtocolConformance *conformance) { +SILGenModule::getWitnessTable(NormalProtocolConformance *conformance, + bool emitAsPrivate) { // If we've already emitted this witness table, return it. auto found = emittedWitnessTables.find(conformance); if (found != emittedWitnessTables.end()) return found->second; - SILWitnessTable *table = SILGenConformance(*this, conformance).emit(); + SILGenConformance conf(*this, conformance); + if (emitAsPrivate) + conf.setLinkage(SILLinkage::Private); + SILWitnessTable *table = conf.emit(); emittedWitnessTables.insert({conformance, table}); return table; diff --git a/lib/Sema/TypeCheckDecl.cpp b/lib/Sema/TypeCheckDecl.cpp index 35e39b60c2e7f..ac62f6375cd37 100644 --- a/lib/Sema/TypeCheckDecl.cpp +++ b/lib/Sema/TypeCheckDecl.cpp @@ -226,16 +226,62 @@ static void checkInheritanceClause( DC = ext; inheritedClause = ext->getInherited(); + ASTContext &ctx = ext->getASTContext(); // Protocol extensions cannot have inheritance clauses. if (auto proto = ext->getExtendedProtocolDecl()) { - if (!inheritedClause.empty()) { + TypeChecker &TC = TypeChecker::createForContext(ctx); + auto lookupOptions = defaultMemberTypeLookupOptions; + lookupOptions -= NameLookupFlags::PerformConformanceCheck; + lookupOptions |= NameLookupFlags::IncludeAttributeImplements; + + if (!inheritedClause.empty()) + proto->inheritedProtocolsChanged(); + + for (unsigned i = 0, n = inheritedClause.size(); i != n; ++i) { + + // Validate the type. + InheritedTypeRequest request{declUnion, i, TypeResolutionStage::Interface}; + Type inheritedTy = evaluateOrDefault(ctx.evaluator, request, Type()); + + // If we couldn't resolve an the inherited type, or it contains an error, + // ignore it. + if (!inheritedTy || inheritedTy->hasError()) + continue; + + if (auto inheritedPr = dyn_cast(inheritedTy + ->getCanonicalType()->getNominalOrBoundGenericNominal())) { + + bool reported = false; + for (auto member : inheritedPr->getMembers()) { + if (auto requirement = dyn_cast(member)) { + auto candidates = TC.lookupMember(DC, ext->getExtendedType(), + requirement->getFullName(), + lookupOptions); + if (candidates.empty() && + requirement->getKind() != DeclKind::AssociatedType) { + if (!reported) { + TC.diagnose(ext->getExtendedTypeLoc().getLoc(), + diag::protocol_extension_does_not_conform, + ext->getExtendedType(), inheritedTy); + reported = true; + } + TC.diagnose(requirement, diag::no_witnesses, + diag::RequirementKind::Func, requirement->getFullName(), + requirement->getInterfaceType(), /*AddFixIt=*/true); + } + } + } + + continue; + } + ext->diagnose(diag::extension_protocol_inheritance, proto->getName()) .highlight(SourceRange(inheritedClause.front().getSourceRange().Start, inheritedClause.back().getSourceRange().End)); - return; } + return; } } else { typeDecl = declUnion.get(); diff --git a/lib/Sema/TypeCheckProtocol.cpp b/lib/Sema/TypeCheckProtocol.cpp index c084baeaa9d25..f8f84e7fc1832 100644 --- a/lib/Sema/TypeCheckProtocol.cpp +++ b/lib/Sema/TypeCheckProtocol.cpp @@ -1293,48 +1293,6 @@ RequirementCheck WitnessChecker::checkWitness(ValueDecl *requirement, # pragma mark Witness resolution -/// This is a wrapper of multiple instances of ConformanceChecker to allow us -/// to diagnose and fix code from a more global perspective; for instance, -/// having this wrapper can help issue a fixit that inserts protocol stubs from -/// multiple protocols under checking. -class swift::MultiConformanceChecker { - TypeChecker &TC; - llvm::SmallVector UnsatisfiedReqs; - llvm::SmallVector AllUsedCheckers; - llvm::SmallVector AllConformances; - llvm::SetVector MissingWitnesses; - llvm::SmallPtrSet CoveredMembers; - - /// Check one conformance. - ProtocolConformance * checkIndividualConformance( - NormalProtocolConformance *conformance, bool issueFixit); - - /// Determine whether the given requirement was left unsatisfied. - bool isUnsatisfiedReq(NormalProtocolConformance *conformance, ValueDecl *req); -public: - MultiConformanceChecker(TypeChecker &TC): TC(TC){} - - TypeChecker &getTypeChecker() const { return TC; } - - /// Add a conformance into the batched checker. - void addConformance(NormalProtocolConformance *conformance) { - AllConformances.push_back(conformance); - } - - /// Peek the unsatisfied requirements collected during conformance checking. - ArrayRef getUnsatisfiedRequirements() { - return llvm::makeArrayRef(UnsatisfiedReqs); - } - - /// Whether this member is "covered" by one of the conformances. - bool isCoveredMember(ValueDecl *member) const { - return CoveredMembers.count(member) > 0; - } - - /// Check all conformances and emit diagnosis globally. - void checkAllConformances(); -}; - bool MultiConformanceChecker:: isUnsatisfiedReq(NormalProtocolConformance *conformance, ValueDecl *req) { if (conformance->isInvalid()) return false; diff --git a/lib/Sema/TypeCheckProtocol.h b/lib/Sema/TypeCheckProtocol.h index 871ffc356bb1e..6a18cc1a1680b 100644 --- a/lib/Sema/TypeCheckProtocol.h +++ b/lib/Sema/TypeCheckProtocol.h @@ -24,6 +24,7 @@ #include "swift/AST/Type.h" #include "swift/AST/Types.h" #include "swift/AST/Witness.h" +#include "swift/AST/GenericSignature.h" #include "llvm/ADT/ScopedHashTable.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallPtrSet.h" @@ -932,6 +933,46 @@ llvm::TinyPtrVector findWitnessedObjCRequirements( const ValueDecl *witness, bool anySingleRequirement = false); -} +/// This is a wrapper of multiple instances of ConformanceChecker to allow us +/// to diagnose and fix code from a more global perspective; for instance, +/// having this wrapper can help issue a fixit that inserts protocol stubs from +/// multiple protocols under checking. +class MultiConformanceChecker { + TypeChecker &TC; + llvm::SmallVector UnsatisfiedReqs; + llvm::SmallVector AllUsedCheckers; + llvm::SmallVector AllConformances; + llvm::SetVector MissingWitnesses; + llvm::SmallPtrSet CoveredMembers; + + /// Check one conformance. + ProtocolConformance * checkIndividualConformance( + NormalProtocolConformance *conformance, bool issueFixit); + + /// Determine whether the given requirement was left unsatisfied. + bool isUnsatisfiedReq(NormalProtocolConformance *conformance, ValueDecl *req); +public: + MultiConformanceChecker(TypeChecker &TC): TC(TC){} + + TypeChecker &getTypeChecker() const { return TC; } + + /// Add a conformance into the batched checker. + void addConformance(NormalProtocolConformance *conformance) { + AllConformances.push_back(conformance); + } + /// Peek the unsatisfied requirements collected during conformance checking. + ArrayRef getUnsatisfiedRequirements() { + return llvm::makeArrayRef(UnsatisfiedReqs); + } + + /// Whether this member is "covered" by one of the conformances. + bool isCoveredMember(ValueDecl *member) const { + return CoveredMembers.count(member) > 0; + } + + /// Check all conformances and emit diagnosis globally. + void checkAllConformances(); +}; +} #endif // SWIFT_SEMA_PROTOCOL_H diff --git a/lib/Serialization/Deserialization.cpp b/lib/Serialization/Deserialization.cpp index 0bc5b865566d6..8c01a0a723347 100644 --- a/lib/Serialization/Deserialization.cpp +++ b/lib/Serialization/Deserialization.cpp @@ -2316,8 +2316,15 @@ class swift::DeclDeserializer { auto inherited = ctx.AllocateCopy(inheritedTypes); if (auto *typeDecl = decl.dyn_cast()) typeDecl->setInherited(inherited); - else - decl.get()->setInherited(inherited); + else { + ExtensionDecl *extension = decl.get(); + extension->setInherited(inherited); + if (!inherited.empty()) { + if (auto extended = extension->getExtendedProtocolDecl()) + extended->inheritedProtocolsChanged(); + extension->getASTContext().InheritingExtensions[extension] = true; + } + } } public: diff --git a/test/Sema/protocol_extension.swift b/test/Sema/protocol_extension.swift new file mode 100644 index 0000000000000..32da6b4c1ba1b --- /dev/null +++ b/test/Sema/protocol_extension.swift @@ -0,0 +1,75 @@ +// RUN: %target-run-simple-swift %s | %FileCheck %s + +extension FixedWidthInteger: ExpressibleByUnicodeScalarLiteral { + @_transparent + public init(unicodeScalarLiteral value: Unicode.Scalar) { + self = Self(value.value) + } + func foo() -> String { + return "Foo!" + } +} + +let a: [Int16] = ["a", "b", "c", "d", "e"] +let b: [UInt32] = ["a", "b", "c", "d", "e"] + +protocol P { +} +protocol Q { + func foo2() -> String +} +extension P {//}: Q { + func foo2() -> String { + return "Foo2 \(self)!" + } +} +class C {} +struct S { + let a = 99 +} +protocol P2: P {} +extension P2 {} +extension Numeric: P2 {} + +extension C: P2 {} +extension S: P2 {} + +// CHECK: Foo2 main.C! +print(C().foo2()) +// CHECK: Foo2 S(a: 99)! +print(S().foo2()) +// CHECK: Foo2 99! +print(Int8(99).foo2()) +// CHECK: Foo2 99! +print(UInt32(99).foo2()) +// CHECK: Foo! +print(Int8(99).foo()) +// CHECK: Foo! +print(Int32(99).foo()) +// CHECK: Foo! +print(UInt32(99).foo()) + +// CHECK: [97, 98, 99, 100, 101] +print(a) +// CHECK: [97, 98, 99, 100, 101] +print(b) +// CHECK: ["Foo2 97!", "Foo2 98!", "Foo2 99!", "Foo2 100!", "Foo2 101!"] +print(a.map {$0.foo2()}) +// CHECK: ["Foo!", "Foo!", "Foo!", "Foo!", "Foo!"] +print(b.map {$0.foo()}) +// CHECK: ["Foo2 97!", "Foo2 98!", "Foo2 99!", "Foo2 100!", "Foo2 101!"] +print(b.map {$0.foo2()}) + +public func use(_ value: Int) -> Int { + return value + "1" // ← Used ExpressibleByUnicodeScalarLiteral +} +//public func use(_ value: T) -> T +// where T : FixedWidthInteger { +// return value + "1" // ← Used ExpressibleByUnicodeScalarLiteral +//} + +print(use(1)) + +let c: P2 = 99.0 +// CHECK: Foo2 99.0! +print(c.foo2()) diff --git a/test/decl/ext/protocol.swift b/test/decl/ext/protocol.swift index 7ef336aef5c38..f0e6885743cb9 100644 --- a/test/decl/ext/protocol.swift +++ b/test/decl/ext/protocol.swift @@ -951,7 +951,7 @@ enum Foo : Int, ReallyRaw { protocol BadProto1 { } protocol BadProto2 { } -extension BadProto1 : BadProto2 { } // expected-error{{extension of protocol 'BadProto1' cannot have an inheritance clause}} +extension BadProto1 : BadProto2 { } // now possible extension BadProto2 { struct S { } // expected-error{{type 'S' cannot be nested in protocol extension of 'BadProto2'}}