diff --git a/include/swift/AST/ASTContext.h b/include/swift/AST/ASTContext.h index 81294f101e24f..4aa509f31ab7c 100644 --- a/include/swift/AST/ASTContext.h +++ b/include/swift/AST/ASTContext.h @@ -332,6 +332,15 @@ class ASTContext final { llvm::BumpPtrAllocator & getAllocator(AllocationArena arena = AllocationArena::Permanent) const; + /// Track protocol extensions that inherit to determine witness tables to emit. + mutable SmallVector ExtensionsWithConformances; + + /// Protocol conformances that are a result of a protocol extension + mutable llvm::DenseMap> ExtendedConformances; + + friend class ProtocolDecl; + public: /// Allocate - Allocate memory from the ASTContext bump pointer. void *Allocate(unsigned long bytes, unsigned alignment, @@ -716,6 +725,19 @@ class ASTContext final { /// one. void loadExtensions(NominalTypeDecl *nominal, unsigned previousGeneration); + /// Keep track of protocol extensions that include a conformance. + /// + /// \param extWithConformances A (protocol) extension that has conformances + void addExtensionWithConformances(ExtensionDecl *extWithConformances) { + ExtensionsWithConformances.emplace_back(extWithConformances); + } + + /// Iterate over conformances arising from protocol extensions. + /// + /// \param emitWitness A function to call to emit a witness table. + void forEachExtendedConformance( + std::function emitWitness); + /// Load the methods within the given class that produce /// Objective-C class or instance methods with the given selector. /// diff --git a/include/swift/AST/Decl.h b/include/swift/AST/Decl.h index 3821c33c631bf..3b36fb65bbea4 100644 --- a/include/swift/AST/Decl.h +++ b/include/swift/AST/Decl.h @@ -1735,7 +1735,7 @@ class ExtensionDecl final : public GenericContext, public Decl, MutableArrayRef getInherited() { return Inherited; } ArrayRef getInherited() const { return Inherited; } - void setInherited(MutableArrayRef i) { Inherited = i; } + void setInherited(MutableArrayRef i); bool hasDefaultAccessLevel() const { return Bits.ExtensionDecl.DefaultAndMaxAccessLevel != 0; @@ -3338,9 +3338,6 @@ class NominalTypeDecl : public GenericTypeDecl, public IterableDeclContext { /// a given nominal type. mutable ConformanceLookupTable *ConformanceTable = nullptr; - /// Prepare the conformance table. - void prepareConformanceTable() const; - /// Returns the protocol requirements that \c Member conforms to. ArrayRef getSatisfiedProtocolRequirementsForMember(const ValueDecl *Member, @@ -3450,6 +3447,9 @@ class NominalTypeDecl : public GenericTypeDecl, public IterableDeclContext { /// conform, such as AnyObject (for classes). void getImplicitProtocols(SmallVectorImpl &protocols); + /// Prepare the conformance table (also acts as accessor). + ConformanceLookupTable *prepareConformanceTable() const; + /// Look for conformances of this nominal type to the given /// protocol. /// @@ -4277,6 +4277,12 @@ class ProtocolDecl final : public NominalTypeDecl { return const_cast(this)->getInheritedProtocolsSlow(); } + /// An extension has inherited a new protocol + void inheritedProtocolsChanged(); + + //// Has a conformance come about from an extension? + bool isExtendedConformance(ProtocolDecl *proto); + /// Determine whether this protocol has a superclass. bool hasSuperclass() const { return (bool)getSuperclassDecl(); } diff --git a/include/swift/AST/DiagnosticsSema.def b/include/swift/AST/DiagnosticsSema.def index 380b28e49d980..e3e68f5b20a5d 100644 --- a/include/swift/AST/DiagnosticsSema.def +++ b/include/swift/AST/DiagnosticsSema.def @@ -1711,6 +1711,8 @@ NOTE(composition_in_extended_type_alternative,none, ERROR(extension_access_with_conformances,none, "%0 modifier cannot be used with extensions that declare " "protocol conformances", (DeclAttribute)) +ERROR(protocol_extension_access_with_conformances,none, + "protocol extensions with conformances must currently be public", ()) ERROR(extension_metatype,none, "cannot extend a metatype %0", (Type)) ERROR(extension_specialization,none, @@ -1723,8 +1725,10 @@ NOTE(extension_stored_property_fixit,none, ERROR(extension_nongeneric_trailing_where,none, "trailing 'where' clause for extension of non-generic type %0", (DeclName)) -ERROR(extension_protocol_inheritance,none, - "extension of protocol %0 cannot have an inheritance clause", (DeclName)) +WARNING(extension_protocol_inheritance,none, + "inheritance clause in extension of protocol %0. " + "use -enable-conforming-protocol-extensions to remove this warning", + (DeclName)) ERROR(objc_generic_extension_using_type_parameter,none, "extension of a generic Objective-C class cannot access the class's " "generic parameters at runtime", ()) diff --git a/include/swift/Basic/LangOptions.h b/include/swift/Basic/LangOptions.h index 8c8da5701f6a5..f3b878736fa13 100644 --- a/include/swift/Basic/LangOptions.h +++ b/include/swift/Basic/LangOptions.h @@ -112,6 +112,9 @@ namespace swift { /// Detect and automatically import modules' cross-import overlays. bool EnableCrossImportOverlays = false; + /// Gate for conforming protocol extensions code. + bool EnableConformingExtensions = false; + /// /// Support for alternate usage modes /// diff --git a/include/swift/Option/Options.td b/include/swift/Option/Options.td index b85b6b415a47d..1122049585775 100644 --- a/include/swift/Option/Options.td +++ b/include/swift/Option/Options.td @@ -504,6 +504,11 @@ def enable_experimental_concise_pound_file : Flag<["-"], Flags<[FrontendOption]>, HelpText<"Enable experimental concise '#file' identifier">; +def enable_conforming_protocol_extensions : Flag<["-"], + "enable-conforming-protocol-extensions">, + Flags<[FrontendOption]>, + HelpText<"Enable experimental feature to allow protocol extensions to conform to protocols">; + // Diagnostic control options def suppress_warnings : Flag<["-"], "suppress-warnings">, Flags<[FrontendOption]>, diff --git a/include/swift/SIL/SILWitnessVisitor.h b/include/swift/SIL/SILWitnessVisitor.h index 4c8bfeb4af50d..b560131f38012 100644 --- a/include/swift/SIL/SILWitnessVisitor.h +++ b/include/swift/SIL/SILWitnessVisitor.h @@ -50,7 +50,15 @@ template class SILWitnessVisitor : public ASTVisitor { public: void visitProtocolDecl(ProtocolDecl *protocol) { + // This chicanery is to move conformances arising from + // protocol extensions to the end of the witness table. + visitProtocolDecl(protocol, false); + visitProtocolDecl(protocol, true); + } + + void visitProtocolDecl(ProtocolDecl *protocol, bool includeExtended) { // The protocol conformance descriptor gets added first. + if (!includeExtended) asDerived().addProtocolConformanceDescriptor(); for (const auto &reqt : protocol->getRequirementSignature()) { @@ -79,12 +87,14 @@ template class SILWitnessVisitor : public ASTVisitor { assert(type->isEqual(protocol->getSelfInterfaceType())); assert(parameter->getDepth() == 0 && parameter->getIndex() == 0 && "non-self type parameter in protocol"); - asDerived().addOutOfLineBaseProtocol(requirement); + if (protocol->isExtendedConformance(requirement) == includeExtended) + asDerived().addOutOfLineBaseProtocol(requirement); continue; } // Otherwise, add an associated requirement. AssociatedConformance assocConf(protocol, type, requirement); + if (!includeExtended) asDerived().addAssociatedConformance(assocConf); continue; } @@ -93,6 +103,7 @@ template class SILWitnessVisitor : public ASTVisitor { } // Add the associated types. + if (!includeExtended) for (auto *associatedType : protocol->getAssociatedTypeMembers()) { // If this is a new associated type (which does not override an // existing associated type), add it. @@ -100,10 +111,11 @@ template class SILWitnessVisitor : public ASTVisitor { asDerived().addAssociatedType(AssociatedType(associatedType)); } - if (asDerived().shouldVisitRequirementSignatureOnly()) - return; +// if (asDerived().shouldVisitRequirementSignatureOnly()) +// return; // Visit the witnesses for the direct members of a protocol. + if (!includeExtended) for (Decl *member : protocol->getMembers()) { ASTVisitor::visit(member); } diff --git a/lib/AST/ASTVerifier.cpp b/lib/AST/ASTVerifier.cpp index 2fe478a3ecfc4..bd6b2ddb3b2a3 100644 --- a/lib/AST/ASTVerifier.cpp +++ b/lib/AST/ASTVerifier.cpp @@ -2620,7 +2620,8 @@ class Verifier : public ASTWalker { } auto proto = conformance->getProtocol(); - if (normal->getDeclContext() != conformingDC) { + if (normal->getDeclContext() != conformingDC && + !isa(conformingDC)) { Out << "AST verification error: conformance of " << nominal->getName().str() << " to protocol " << proto->getName().str() << " is in the wrong context.\n" diff --git a/lib/AST/ConformanceLookupTable.cpp b/lib/AST/ConformanceLookupTable.cpp index 414d3d8803984..677153fdd9e3d 100644 --- a/lib/AST/ConformanceLookupTable.cpp +++ b/lib/AST/ConformanceLookupTable.cpp @@ -25,6 +25,7 @@ #include "swift/AST/ProtocolConformance.h" #include "swift/AST/ProtocolConformanceRef.h" #include "llvm/Support/SaveAndRestore.h" +#include "../Sema/TypeCheckProtocol.h" using namespace swift; @@ -475,11 +476,60 @@ void ConformanceLookupTable::addInheritedProtocols( } } +void ConformanceLookupTable::addImpliedProtocols( + NominalTypeDecl *nominal, + llvm::PointerUnion decl, + ConformanceSource source, + llvm::DenseMap &Seen, + Propagator *propagator) { + bool anyObject = false; + ExtensionDecl *extension = decl.dyn_cast(); + NominalTypeDecl *extending = extension ? extension->getExtendedNominal() : + dyn_cast(decl.dyn_cast()); + + // Closure to register inherited protocols against extending. + Propagator registerWitness = [&](ProtocolDecl *inheritedProto) { + if (extension && extending) { + auto *table = extending->prepareConformanceTable(); + table->NominalConformancesFromExtension[extension][nominal][inheritedProto] = true; + } + // Continue propagating up stack. + if (propagator) + (*propagator)(inheritedProto); + }; + + // Find all of the protocols in the inheritance list. + unsigned count = 0, expected = extension ? extension->getInherited().size() : 0; + for (const auto &found : + getDirectlyInheritedNominalTypeDecls(decl, anyObject)) { + if (Seen.find(found.Item) == Seen.end()) { + Seen[found.Item] = extending; // prevent circularity (diagnose somewhere?) + addImpliedProtocols(nominal, found.Item, source, Seen, ®isterWitness); + } + count++; + } + assert(!(count < expected)); + + if (!extension && extending) { + for (ExtensionDecl *ext : extending->getExtensions()) + addImpliedProtocols(nominal, ext, source, Seen, ®isterWitness); + + if (auto proto = dyn_cast(extending)) { + addProtocol(proto, extension ? extension->getLoc() : extending->getLoc(), source); + if (propagator) + (*propagator)(/*inheritedProto*/proto); + } + } +} + void ConformanceLookupTable::expandImpliedConformances(NominalTypeDecl *nominal, DeclContext *dc) { // Note: recursive type-checking implies that AllConformances // may be reallocated during this traversal, so pay the lookup cost // during each iteration. + llvm::DenseMap Seen; + unsigned topLevel = AllConformances[dc].size(); + for (unsigned i = 0; i != AllConformances[dc].size(); ++i) { /// FIXME: Avoid the possibility of an infinite loop by fixing the root /// cause instead (incomplete circularity detection). @@ -504,8 +554,9 @@ void ConformanceLookupTable::expandImpliedConformances(NominalTypeDecl *nominal, } } - addInheritedProtocols(conformingProtocol, - ConformanceSource::forImplied(conformanceEntry)); + if (i < topLevel) + addImpliedProtocols(nominal, conformingProtocol, + ConformanceSource::forImplied(conformanceEntry), Seen); } } @@ -1183,3 +1234,85 @@ void ConformanceLookupTable::dump(raw_ostream &os) const { } } +// Miscellaneous code added to implement conforming protocol extensions + +void ConformanceLookupTable::addExtendedConformances(const ExtensionDecl *ext, + SmallVectorImpl &conformances) { + MultiConformanceChecker groupChecker(ext->getASTContext()); + for (auto &ext : NominalConformancesFromExtension) + for (auto &nominalPair : ext.second) { + NominalTypeDecl *nominal = nominalPair.first; + if (nominal->getSelfProtocolDecl()) + continue; + for (auto &protocolPair : nominalPair.second) { + ProtocolDecl *proto = protocolPair.first; + auto table = nominal->prepareConformanceTable(); + auto entry = table->Conformances.find(proto); + if (entry == table->Conformances.end()) + continue; + + auto conformance = table->getConformance(nominal, entry->second.back()); + if (auto normal = dyn_cast(conformance)) { + conformances.push_back(conformance); + groupChecker.addConformance(normal); + } + } + } + + if (groupChecker.checkAllConformances()) + exit(EXIT_FAILURE); +} + +void ConformanceLookupTable::invalidate(NominalTypeDecl *nomimal) { + for (auto &extInfo : NominalConformancesFromExtension) + for (auto &toInvalidate : extInfo.second) { + NominalTypeDecl *nominal = toInvalidate.first; + if (auto *proto = dyn_cast(nominal)) + proto->inheritedProtocolsChanged(); + else + nominal->prepareConformanceTable()->invalidate(nominal); + } + + LastProcessed.clear(); +} + +void ProtocolDecl::inheritedProtocolsChanged() { + Bits.ProtocolDecl.InheritedProtocolsValid = false; + prepareConformanceTable()->invalidate(this); + RequirementSignature = nullptr; +} + +bool ProtocolDecl::isExtendedConformance(ProtocolDecl *proto) { + if (!Bits.ProtocolDecl.InheritedProtocolsValid) + (void)getInheritedProtocolsSlow(); + auto &extendedProtocols = getASTContext().ExtendedConformances; + auto extendedProto = extendedProtocols.find(this); + return extendedProto != extendedProtocols.end() && + extendedProto->second.find(proto) != extendedProto->second.end(); +} + +void ExtensionDecl::setInherited(MutableArrayRef i) { + Inherited = i; + if (!Inherited.empty()) { + getASTContext().addExtensionWithConformances(this); + + if (hasBeenBound()) + if (auto *proto = getExtendedProtocolDecl()) + proto->inheritedProtocolsChanged(); + } +} + +void ASTContext::forEachExtendedConformance( + std::function emitWitness) { + for (ExtensionDecl *ext : ExtensionsWithConformances) + if (ext->hasBeenBound()) + if (ProtocolDecl *proto = ext->getExtendedProtocolDecl()) { + SmallVector result; + proto->prepareConformanceTable()->addExtendedConformances(ext, result); + for (auto conformance : result) + if (auto *normal = dyn_cast(conformance)) + if (!conformance->getType()->getCanonicalType() + ->getAnyNominal()->getSelfProtocolDecl()) + emitWitness(normal); + } +} diff --git a/lib/AST/ConformanceLookupTable.h b/lib/AST/ConformanceLookupTable.h index a27c89f03e063..d42eb2d905691 100644 --- a/lib/AST/ConformanceLookupTable.h +++ b/lib/AST/ConformanceLookupTable.h @@ -319,6 +319,12 @@ class ConformanceLookupTable { llvm::DenseMap> ConformingDeclMap; + /// Tracks nominals that have an implied conformance from inheriting protocol extension + /// Used to know nominals that need to refresh their conformances and which witnesses to emit. + llvm::DenseMap>> NominalConformancesFromExtension; + + /// Indicates whether we are visiting the superclass. bool VisitingSuperclass = false; @@ -331,6 +337,16 @@ class ConformanceLookupTable { llvm::PointerUnion decl, ConformanceSource source); + /// Used to propagate conformances up to protocols being extended (by conformances). + using Propagator = std::function; + + /// Recursively add inherited protocols, register conformance infered from protocol extension. + void addImpliedProtocols(NominalTypeDecl *nominal, + llvm::PointerUnion decl, + ConformanceSource source, + llvm::DenseMap &Seen, + Propagator *propagator = nullptr); + /// Expand the implied conformances for the given DeclContext. void expandImpliedConformances(NominalTypeDecl *nominal, DeclContext *dc); @@ -413,6 +429,9 @@ class ConformanceLookupTable { /// Create a new conformance lookup table. ConformanceLookupTable(ASTContext &ctx); + /// Force subseqent recalulation of conformances + void invalidate(NominalTypeDecl *nomimal); + /// Destroy the conformance table. void destroy(); @@ -443,6 +462,10 @@ class ConformanceLookupTable { SmallVectorImpl *conformances, SmallVectorImpl *diagnostics); + /// Add conformances implied by an inheriting protocol extension + void addExtendedConformances(const ExtensionDecl *ext, + SmallVectorImpl &conformances); + /// Retrieve the complete set of protocols to which this nominal /// type conforms. void getAllProtocols(NominalTypeDecl *nominal, diff --git a/lib/AST/Decl.cpp b/lib/AST/Decl.cpp index 9df38af966127..8baec49b204bd 100644 --- a/lib/AST/Decl.cpp +++ b/lib/AST/Decl.cpp @@ -1231,12 +1231,12 @@ ExtensionDecl::ExtensionDecl(SourceLoc extensionLoc, Decl(DeclKind::Extension, parent), IterableDeclContext(IterableDeclContextKind::ExtensionDecl), ExtensionLoc(extensionLoc), - ExtendedTypeRepr(extendedType), - Inherited(inherited) + ExtendedTypeRepr(extendedType) { Bits.ExtensionDecl.DefaultAndMaxAccessLevel = 0; Bits.ExtensionDecl.HasLazyConformances = false; setTrailingWhereClause(trailingWhereClause); + setInherited(inherited); } ExtensionDecl *ExtensionDecl::create(ASTContext &ctx, SourceLoc extensionLoc, @@ -4645,6 +4645,16 @@ ProtocolDecl::getInheritedProtocolsSlow() { } auto &ctx = getASTContext(); + // Protocol extensions with conformances. + for (auto ext : getExtensions()) + for (const auto found : + getDirectlyInheritedNominalTypeDecls(ext, anyObject)) + if (auto proto = dyn_cast(found.Item)) + if (known.insert(proto).second) { + result.push_back(proto); + ctx.ExtendedConformances[this][proto] = ext; + } + InheritedProtocols = ctx.AllocateCopy(result); return InheritedProtocols; } diff --git a/lib/AST/GenericSignatureBuilder.cpp b/lib/AST/GenericSignatureBuilder.cpp index 6ddc863685778..cf7baa9d65582 100644 --- a/lib/AST/GenericSignatureBuilder.cpp +++ b/lib/AST/GenericSignatureBuilder.cpp @@ -3838,6 +3838,22 @@ static ConstraintResult visitInherited( visitInherited(inheritedType, inherited.getTypeRepr()); } + if (auto *protoDecl = dyn_cast_or_null(typeDecl)) + for (auto *ext : protoDecl->getExtensions()) { + ArrayRef inheritedTypes = ext->getInherited(); + for (unsigned index : indices(inheritedTypes)) { + Type inheritedType + = evaluateOrDefault(evaluator, + InheritedTypeRequest{ext, index, + TypeResolutionStage::Structural}, + Type()); + if (!inheritedType) continue; + + const auto &inherited = inheritedTypes[index]; + visitInherited(inheritedType, inherited.getTypeRepr()); + } + } + return result; } @@ -4456,6 +4472,23 @@ ConstraintResult GenericSignatureBuilder::addTypeRequirement( anyErrors = true; } + // Protocol extensions with conformances. + if (auto *protoDecl = + dyn_cast_or_null(constraintType->getAnyNominal())) { + auto &conforms = resolvedSubject.getEquivalenceClass(*this)->conformsTo; + + for (auto *ext : protoDecl->getExtensions()) + for (auto &typeLoc : ext->getInherited()) + if (auto inheritedTy = typeLoc.getType()) + if (auto *inheritedProto = + dyn_cast_or_null(inheritedTy->getAnyNominal())) + if (conforms.find(inheritedProto) == conforms.end() && + isErrorResult(addConformanceRequirement(resolvedSubject, + inheritedProto, + source))) + anyErrors = true; + } + return anyErrors ? ConstraintResult::Conflicting : ConstraintResult::Resolved; } diff --git a/lib/AST/ProtocolConformance.cpp b/lib/AST/ProtocolConformance.cpp index 84626c26178cc..86508e9f5f4ba 100644 --- a/lib/AST/ProtocolConformance.cpp +++ b/lib/AST/ProtocolConformance.cpp @@ -1221,9 +1221,9 @@ ProtocolConformance::getInheritedConformance(ProtocolDecl *protocol) const { } #pragma mark Protocol conformance lookup -void NominalTypeDecl::prepareConformanceTable() const { +ConformanceLookupTable *NominalTypeDecl::prepareConformanceTable() const { if (ConformanceTable) - return; + return ConformanceTable; auto mutableThis = const_cast(this); ASTContext &ctx = getASTContext(); @@ -1236,7 +1236,7 @@ void NominalTypeDecl::prepareConformanceTable() const { if (file->getKind() != FileUnitKind::Source && file->getKind() != FileUnitKind::ClangModule && file->getKind() != FileUnitKind::DWARFModule) { - return; + return ConformanceTable; } SmallPtrSet protocols; @@ -1270,6 +1270,8 @@ void NominalTypeDecl::prepareConformanceTable() const { addSynthesized(KnownProtocolKind::RawRepresentable); } } + + return ConformanceTable; } bool NominalTypeDecl::lookupConformance( diff --git a/lib/Driver/ToolChains.cpp b/lib/Driver/ToolChains.cpp index a42a76143c856..bd87ba46bd547 100644 --- a/lib/Driver/ToolChains.cpp +++ b/lib/Driver/ToolChains.cpp @@ -256,6 +256,8 @@ static void addCommonFrontendArgs(const ToolChain &TC, const OutputInfo &OI, options::OPT_enable_experimental_concise_pound_file); inputArgs.AddLastArg(arguments, options::OPT_verify_incremental_dependencies); + inputArgs.AddLastArg(arguments, + options::OPT_enable_conforming_protocol_extensions); // Pass on any build config options inputArgs.AddAllArgs(arguments, options::OPT_D); diff --git a/lib/Frontend/CompilerInvocation.cpp b/lib/Frontend/CompilerInvocation.cpp index 3cc2a481a5fe8..4a459de80d3ac 100644 --- a/lib/Frontend/CompilerInvocation.cpp +++ b/lib/Frontend/CompilerInvocation.cpp @@ -536,6 +536,8 @@ static bool ParseLangArgs(LangOptions &Opts, ArgList &Args, Opts.EnableConcisePoundFile = Args.hasArg(OPT_enable_experimental_concise_pound_file); + Opts.EnableConformingExtensions = + Args.hasArg(OPT_enable_conforming_protocol_extensions); Opts.EnableCrossImportOverlays = Args.hasFlag(OPT_enable_cross_import_overlays, diff --git a/lib/IRGen/GenProto.cpp b/lib/IRGen/GenProto.cpp index 3a6567b94b21c..8846607c445f3 100644 --- a/lib/IRGen/GenProto.cpp +++ b/lib/IRGen/GenProto.cpp @@ -1080,9 +1080,9 @@ mapConformanceIntoContext(IRGenModule &IGM, const RootProtocolConformance &conf, WitnessIndex ProtocolInfo::getAssociatedTypeIndex( IRGenModule &IGM, AssociatedType assocType) const { - assert(!IGM.isResilient(assocType.getSourceProtocol(), - ResilienceExpansion::Maximal) && - "Cannot ask for the associated type index of non-resilient protocol"); +// assert(!IGM.isResilient(assocType.getSourceProtocol(), +// ResilienceExpansion::Maximal) && +// "Cannot ask for the associated type index of non-resilient protocol"); for (auto &witness : getWitnessEntries()) { if (witness.matchesAssociatedType(assocType)) return getNonBaseWitnessIndex(&witness); @@ -1948,6 +1948,8 @@ void IRGenModule::emitProtocolConformance( getAddrOfProtocolConformanceDescriptor(conformance, init.finishAndCreateFuture())); var->setConstant(true); + if (record.wtable->getLinkage() == SILLinkage::Private) + var->setLinkage(llvm::GlobalVariable::LinkageTypes::PrivateLinkage); setTrueConstGlobal(var); } @@ -1986,8 +1988,8 @@ void IRGenerator::ensureRelativeSymbolCollocation(SILDefaultWitnessTable &wt) { const ProtocolInfo &IRGenModule::getProtocolInfo(ProtocolDecl *protocol, ProtocolInfoKind kind) { // If the protocol is resilient, we cannot know the full witness table layout. - assert(!isResilient(protocol, ResilienceExpansion::Maximal) || - kind == ProtocolInfoKind::RequirementSignature); +// assert(!isResilient(protocol, ResilienceExpansion::Maximal) || +// kind == ProtocolInfoKind::RequirementSignature); return Types.getProtocolInfo(protocol, kind); } @@ -2155,6 +2157,9 @@ void IRGenModule::emitSILWitnessTable(SILWitnessTable *wt) { tableSize = wtableBuilder.getTableSize(); instantiationFunction = wtableBuilder.buildInstantiationFunction(); + + if (wt->getLinkage() == SILLinkage::Private) + global->setLinkage(llvm::GlobalVariable::LinkageTypes::PrivateLinkage); } else { // Build the witness table. ResilientWitnessTableBuilder wtableBuilder(*this, wt); diff --git a/lib/SILGen/SILGen.cpp b/lib/SILGen/SILGen.cpp index 788d8edde0701..e4f15344ee076 100644 --- a/lib/SILGen/SILGen.cpp +++ b/lib/SILGen/SILGen.cpp @@ -1831,6 +1831,11 @@ class SILGenModuleRAII { continue; SGM.visit(TD); } + + SGM.getASTContext() + .forEachExtendedConformance([&](NormalProtocolConformance *normal) { + SGM.getWitnessTable(normal, /*emitAsPrivate*/true); + }); } SILGenModuleRAII(SILModule &M, ModuleDecl *SM) : SGM{M, SM} {} diff --git a/lib/SILGen/SILGen.h b/lib/SILGen/SILGen.h index 287d656fe1261..811d28cc9afb5 100644 --- a/lib/SILGen/SILGen.h +++ b/lib/SILGen/SILGen.h @@ -333,7 +333,8 @@ class LLVM_LIBRARY_VISIBILITY SILGenModule : public ASTVisitor { void emitObjCDestructorThunk(DestructorDecl *destructor); /// Get or emit the witness table for a protocol conformance. - SILWitnessTable *getWitnessTable(NormalProtocolConformance *conformance); + SILWitnessTable *getWitnessTable(NormalProtocolConformance *conformance, + bool emitAsPrivate = false); /// Emit a protocol witness entry point. SILFunction * diff --git a/lib/SILGen/SILGenType.cpp b/lib/SILGen/SILGenType.cpp index 3f19296f7478d..d4be1571b4d36 100644 --- a/lib/SILGen/SILGenType.cpp +++ b/lib/SILGen/SILGenType.cpp @@ -615,13 +615,17 @@ class SILGenConformance : public SILGenWitnessTable { } // end anonymous namespace SILWitnessTable * -SILGenModule::getWitnessTable(NormalProtocolConformance *conformance) { +SILGenModule::getWitnessTable(NormalProtocolConformance *conformance, + bool emitAsPrivate) { // If we've already emitted this witness table, return it. auto found = emittedWitnessTables.find(conformance); if (found != emittedWitnessTables.end()) return found->second; - SILWitnessTable *table = SILGenConformance(*this, conformance).emit(); + SILGenConformance conf(*this, conformance); + if (emitAsPrivate) + conf.Linkage = SILLinkage::Private; + SILWitnessTable *table = conf.emit(); emittedWitnessTables.insert({conformance, table}); return table; diff --git a/lib/Sema/CSApply.cpp b/lib/Sema/CSApply.cpp index f2a5f84f4689a..a9fe9fa5f7223 100644 --- a/lib/Sema/CSApply.cpp +++ b/lib/Sema/CSApply.cpp @@ -6883,8 +6883,12 @@ Expr *ExprRewriter::convertLiteralInPlace(Expr *literal, // Extract the literal type. Type builtinLiteralType = conformance.getTypeWitnessByName(type, literalType); - if (builtinLiteralType->hasError()) + if (builtinLiteralType->hasError()) { + cs.getASTContext().Diags.diagnose(literal->getLoc(), + diag::type_does_not_conform, + type, protocol->getDeclaredType()); return nullptr; + } // Perform the builtin conversion. if (!convertLiteralInPlace(literal, builtinLiteralType, nullptr, diff --git a/lib/Sema/TypeCheckAttr.cpp b/lib/Sema/TypeCheckAttr.cpp index 1a7690fba878f..57c9a89e77c8f 100644 --- a/lib/Sema/TypeCheckAttr.cpp +++ b/lib/Sema/TypeCheckAttr.cpp @@ -707,8 +707,10 @@ bool AttributeChecker::visitAbstractAccessControlAttr( if (auto extension = dyn_cast(D)) { if (!extension->getInherited().empty()) { - diagnoseAndRemoveAttr(attr, diag::extension_access_with_conformances, - attr); + if (!extension->getExtendedProtocolDecl()) + diagnoseAndRemoveAttr(attr, diag::extension_access_with_conformances, attr); + else if (!(attr->getAccess() == AccessLevel::Public)) + diagnoseAndRemoveAttr(attr, diag::protocol_extension_access_with_conformances); return true; } } diff --git a/lib/Sema/TypeCheckDeclPrimary.cpp b/lib/Sema/TypeCheckDeclPrimary.cpp index 2f18f2d829884..6b65ae3388a38 100644 --- a/lib/Sema/TypeCheckDeclPrimary.cpp +++ b/lib/Sema/TypeCheckDeclPrimary.cpp @@ -85,11 +85,19 @@ static void checkInheritanceClause( // Protocol extensions cannot have inheritance clauses. if (auto proto = ext->getExtendedProtocolDecl()) { if (!inheritedClause.empty()) { - ext->diagnose(diag::extension_protocol_inheritance, - proto->getName()) - .highlight(SourceRange(inheritedClause.front().getSourceRange().Start, - inheritedClause.back().getSourceRange().End)); - return; + // Force recalculation conformance table + ext->setInherited(inheritedClause); + + auto *attr = ext->getAttrs().getAttribute(); + if (!attr || attr->getAccess() < AccessLevel::Public) + ext->diagnose(diag::protocol_extension_access_with_conformances) + .fixItInsert(ext->getStartLoc(), "public "); + + // Warning (eventually error?) here is now gated + if (!ext->getASTContext().LangOpts.EnableConformingExtensions) + ext->diagnose(diag::extension_protocol_inheritance, proto->getName()) + .highlight(SourceRange(inheritedClause.front().getSourceRange().Start, + inheritedClause.back().getSourceRange().End)); } } } else { diff --git a/lib/Sema/TypeCheckProtocol.cpp b/lib/Sema/TypeCheckProtocol.cpp index 5cfe36429cbe5..4f44ec9bfa76c 100644 --- a/lib/Sema/TypeCheckProtocol.cpp +++ b/lib/Sema/TypeCheckProtocol.cpp @@ -1382,48 +1382,6 @@ RequirementCheck WitnessChecker::checkWitness(ValueDecl *requirement, # pragma mark Witness resolution -/// This is a wrapper of multiple instances of ConformanceChecker to allow us -/// to diagnose and fix code from a more global perspective; for instance, -/// having this wrapper can help issue a fixit that inserts protocol stubs from -/// multiple protocols under checking. -class swift::MultiConformanceChecker { - ASTContext &Context; - llvm::SmallVector UnsatisfiedReqs; - llvm::SmallVector AllUsedCheckers; - llvm::SmallVector AllConformances; - llvm::SetVector MissingWitnesses; - llvm::SmallPtrSet CoveredMembers; - - /// Check one conformance. - ProtocolConformance * checkIndividualConformance( - NormalProtocolConformance *conformance, bool issueFixit); - - /// Determine whether the given requirement was left unsatisfied. - bool isUnsatisfiedReq(NormalProtocolConformance *conformance, ValueDecl *req); -public: - MultiConformanceChecker(ASTContext &ctx) : Context(ctx) {} - - ASTContext &getASTContext() const { return Context; } - - /// Add a conformance into the batched checker. - void addConformance(NormalProtocolConformance *conformance) { - AllConformances.push_back(conformance); - } - - /// Peek the unsatisfied requirements collected during conformance checking. - ArrayRef getUnsatisfiedRequirements() { - return llvm::makeArrayRef(UnsatisfiedReqs); - } - - /// Whether this member is "covered" by one of the conformances. - bool isCoveredMember(ValueDecl *member) const { - return CoveredMembers.count(member) > 0; - } - - /// Check all conformances and emit diagnosis globally. - void checkAllConformances(); -}; - bool MultiConformanceChecker:: isUnsatisfiedReq(NormalProtocolConformance *conformance, ValueDecl *req) { if (conformance->isInvalid()) return false; @@ -1449,7 +1407,7 @@ isUnsatisfiedReq(NormalProtocolConformance *conformance, ValueDecl *req) { return false; } -void MultiConformanceChecker::checkAllConformances() { +bool MultiConformanceChecker::checkAllConformances() { bool anyInvalid = false; for (unsigned I = 0, N = AllConformances.size(); I < N; I ++) { auto *conformance = AllConformances[I]; @@ -1474,7 +1432,7 @@ void MultiConformanceChecker::checkAllConformances() { } // If all missing witnesses are issued with fixits, we are done. if (MissingWitnesses.empty()) - return; + return anyInvalid; // Otherwise, backtrack to the last checker that has missing witnesses // and diagnose missing witnesses from there. @@ -1484,6 +1442,8 @@ void MultiConformanceChecker::checkAllConformances() { It->diagnoseMissingWitnesses(MissingWitnessDiagnosisKind::FixItOnly); } } + + return true; } static void diagnoseConformanceImpliedByConditionalConformance( diff --git a/lib/Sema/TypeCheckProtocol.h b/lib/Sema/TypeCheckProtocol.h index 9050e88102637..7f1fbefac3870 100644 --- a/lib/Sema/TypeCheckProtocol.h +++ b/lib/Sema/TypeCheckProtocol.h @@ -948,6 +948,43 @@ llvm::TinyPtrVector findWitnessedObjCRequirements( const ValueDecl *witness, bool anySingleRequirement = false); +class MultiConformanceChecker { + ASTContext &Context; + llvm::SmallVector UnsatisfiedReqs; + llvm::SmallVector AllUsedCheckers; + llvm::SmallVector AllConformances; + llvm::SetVector MissingWitnesses; + llvm::SmallPtrSet CoveredMembers; + + /// Check one conformance. + ProtocolConformance * checkIndividualConformance( + NormalProtocolConformance *conformance, bool issueFixit); + + /// Determine whether the given requirement was left unsatisfied. + bool isUnsatisfiedReq(NormalProtocolConformance *conformance, ValueDecl *req); +public: + MultiConformanceChecker(ASTContext &ctx) : Context(ctx) {} + + ASTContext &getASTContext() const { return Context; } + + /// Add a conformance into the batched checker. + void addConformance(NormalProtocolConformance *conformance) { + AllConformances.push_back(conformance); + } + + /// Peek the unsatisfied requirements collected during conformance checking. + ArrayRef getUnsatisfiedRequirements() { + return llvm::makeArrayRef(UnsatisfiedReqs); + } + + /// Whether this member is "covered" by one of the conformances. + bool isCoveredMember(ValueDecl *member) const { + return CoveredMembers.count(member) > 0; + } + + /// Check all conformances and emit diagnosis globally. + bool checkAllConformances(); +}; } #endif // SWIFT_SEMA_PROTOCOL_H diff --git a/test/decl/conforming_extensions.swift b/test/decl/conforming_extensions.swift new file mode 100644 index 0000000000000..62ca11fa19056 --- /dev/null +++ b/test/decl/conforming_extensions.swift @@ -0,0 +1,96 @@ +// RUN: %target-run-simple-swift -enable-conforming-protocol-extensions %s | %FileCheck %s + +public extension FixedWidthInteger: ExpressibleByUnicodeScalarLiteral { + @_transparent + init(unicodeScalarLiteral value: Unicode.Scalar) { + self = Self(value.value) + } + func foo() -> String { + return "Foo!" + } +} + +let a: [Int16] = ["a", "b", "c", "d", "e"] +let b: [UInt32] = ["a", "b", "c", "d", "e"] + +protocol Q { + var bar: Double { get } + func foo2() -> String +} +public protocol P { +} +public extension P: Q { + var bar: Double { return -888 } + func foo2() -> String { + return "Foo2 \(self)!" + } +} +class C {} +struct S { + let a = 99 +} +public protocol P2: P {} + +extension C: P2 {} +extension S: P2 {} +//extension FixedWidthInteger: P2 {} +extension Numeric: P2 {} + +//extension P2 {} + +//var a2: String? +//_ = a2! + +// CHECK: Foo2 main.C! +print(C().foo2()) +// CHECK: Foo2 S(a: 99)! +print(S().foo2()) +// CHECK: Foo2 99! +print(Int8(99).foo2()) +// CHECK: Foo2 99! +//extension UInt32: P2 {} +print(UInt32(99).foo2()) +// CHECK: Foo! +print(Int8(99).foo()) +// CHECK: Foo! +print(Int32(99).foo()) +// CHECK: Foo! +print(UInt32(99).foo()) + +// CHECK: [97, 98, 99, 100, 101] +print(a) +// CHECK: [97, 98, 99, 100, 101] +print(b) +// CHECK: ["Foo2 97!", "Foo2 98!", "Foo2 99!", "Foo2 100!", "Foo2 101!"] +print(a.map {$0.foo2()}) +// CHECK: ["Foo!", "Foo!", "Foo!", "Foo!", "Foo!"] +print(b.map {$0.foo()}) +// CHECK: ["Foo2 97!", "Foo2 98!", "Foo2 99!", "Foo2 100!", "Foo2 101!"] +print(b.map {$0.foo2()}) + +public func use(_ value: Int) -> Int { + return value + "1" // ← Used ExpressibleByUnicodeScalarLiteral +} +//public func use(_ value: T) -> T +// where T : FixedWidthInteger { +// return value + "1" // ← Used ExpressibleByUnicodeScalarLiteral +//} + +print(use(1)) + +let u = UInt32(99) + +func aaa(_ b: P2) { + print(b.foo2()) +} + +aaa(u) + +aaa(88.0) + +print(99.0.foo2()) + +let v = b.map { $0 as P2 }.map { $0.foo2() } +print(v) + +print(99.bar) diff --git a/test/decl/ext/protocol.swift b/test/decl/ext/protocol.swift index 2628bbc8aab11..f3bdfbbdfd6d1 100644 --- a/test/decl/ext/protocol.swift +++ b/test/decl/ext/protocol.swift @@ -947,7 +947,7 @@ enum Foo : Int, ReallyRaw { protocol BadProto1 { } protocol BadProto2 { } -extension BadProto1 : BadProto2 { } // expected-error{{extension of protocol 'BadProto1' cannot have an inheritance clause}} +extension BadProto1 : BadProto2 { } // expected-warning{{inheritance clause in extension of protocol 'BadProto1'. use -enable-conforming-protocol-extensions to remove this warning}} extension BadProto2 { struct S { } // expected-error{{type 'S' cannot be nested in protocol extension of 'BadProto2'}}