diff --git a/include/swift/AST/ASTContext.h b/include/swift/AST/ASTContext.h index 665fcb0894313..f1160da0e48b3 100644 --- a/include/swift/AST/ASTContext.h +++ b/include/swift/AST/ASTContext.h @@ -335,7 +335,14 @@ class ASTContext final { llvm::BumpPtrAllocator & getAllocator(AllocationArena arena = AllocationArena::Permanent) const; + /// Record of conformances that have come about by extending a protocol + llvm::DenseMap> ExtendedConformances; + public: + llvm::DenseMap> &getExtendedConformances() { return ExtendedConformances; } + /// Allocate - Allocate memory from the ASTContext bump pointer. void *Allocate(unsigned long bytes, unsigned alignment, AllocationArena arena = AllocationArena::Permanent) const { @@ -729,6 +736,12 @@ class ASTContext final { /// one. void loadExtensions(NominalTypeDecl *nominal, unsigned previousGeneration); + /// Iterate over conformances arising from protocol extensions. + /// + /// \param emitWitness A function to call to emit a witness table. + void forEachExtendedConformance(ModuleDecl *module, + std::function emitWitness); + /// Load the methods within the given class that produce /// Objective-C class or instance methods with the given selector. /// diff --git a/include/swift/AST/Decl.h b/include/swift/AST/Decl.h index 78e04e69598cf..5b05d7201763b 100644 --- a/include/swift/AST/Decl.h +++ b/include/swift/AST/Decl.h @@ -1754,7 +1754,7 @@ class ExtensionDecl final : public GenericContext, public Decl, MutableArrayRef getInherited() { return Inherited; } ArrayRef getInherited() const { return Inherited; } - void setInherited(MutableArrayRef i) { Inherited = i; } + void setInherited(MutableArrayRef i); bool hasDefaultAccessLevel() const { return Bits.ExtensionDecl.DefaultAndMaxAccessLevel != 0; @@ -3367,9 +3367,6 @@ class NominalTypeDecl : public GenericTypeDecl, public IterableDeclContext { /// a given nominal type. mutable ConformanceLookupTable *ConformanceTable = nullptr; - /// Prepare the conformance table. - void prepareConformanceTable() const; - /// Returns the protocol requirements that \c Member conforms to. ArrayRef getSatisfiedProtocolRequirementsForMember(const ValueDecl *Member, @@ -3479,6 +3476,9 @@ class NominalTypeDecl : public GenericTypeDecl, public IterableDeclContext { /// conform, such as AnyObject (for classes). void getImplicitProtocols(SmallVectorImpl &protocols); + /// Prepare the conformance table (also acts as accessor). + ConformanceLookupTable *prepareConformanceTable() const; + /// Look for conformances of this nominal type to the given /// protocol. /// @@ -4332,6 +4332,9 @@ class ProtocolDecl final : public NominalTypeDecl { /// Retrieve the set of protocols inherited from this protocol. ArrayRef getInheritedProtocols() const; + /// An extension has inherited a new protocol + void inheritedProtocolsChanged(); + /// Determine whether this protocol has a superclass. bool hasSuperclass() const { return (bool)getSuperclassDecl(); } @@ -4422,6 +4425,9 @@ class ProtocolDecl final : public NominalTypeDecl { /// contain 'Self' in 'parameter' or 'other' position. bool existentialTypeSupported() const; + /// Track conformances that have come about due to a protocol extension + void recordExtendedNominal(NominalTypeDecl *nomial, ExtensionDecl *ext); + private: void computeKnownProtocolKind() const; diff --git a/include/swift/AST/DiagnosticsSema.def b/include/swift/AST/DiagnosticsSema.def index e8581d3cbc515..4c6c06659115f 100644 --- a/include/swift/AST/DiagnosticsSema.def +++ b/include/swift/AST/DiagnosticsSema.def @@ -1751,6 +1751,8 @@ NOTE(composition_in_extended_type_alternative,none, ERROR(extension_access_with_conformances,none, "%0 modifier cannot be used with extensions that declare " "protocol conformances", (DeclAttribute)) +ERROR(protocol_extension_access_with_conformances,none, + "protocol extensions with conformances must currently be public", ()) ERROR(extension_metatype,none, "cannot extend a metatype %0", (Type)) ERROR(extension_specialization,none, @@ -1764,8 +1766,12 @@ ERROR(extension_nongeneric_trailing_where,none, "trailing 'where' clause for extension of non-generic type %0", (Identifier)) ERROR(extension_protocol_inheritance,none, - "extension of protocol %0 cannot have an inheritance clause", - (Identifier)) + "inheritance clause in extension of protocol %0. " + "use -enable-conforming-protocol-extensions to enable this feature", + (DeclName)) +ERROR(extension_protocol_limitation,none, + "cannot extend protocol %0 that %1 already conforms to in another module", + (DeclName, Type)) ERROR(objc_generic_extension_using_type_parameter,none, "extension of a generic Objective-C class cannot access the class's " "generic parameters at runtime", ()) diff --git a/include/swift/AST/NameLookupRequests.h b/include/swift/AST/NameLookupRequests.h index e71c789fbfb0a..2529addc4b667 100644 --- a/include/swift/AST/NameLookupRequests.h +++ b/include/swift/AST/NameLookupRequests.h @@ -182,7 +182,9 @@ class InheritedProtocolsRequest public: // Caching - bool isCached() const { return true; } + bool isCached() const { + return std::get<0>(getStorage())->areInheritedProtocolsValid(); + } Optional> getCachedResult() const; void cacheResult(ArrayRef value) const; diff --git a/include/swift/AST/ProtocolConformance.h b/include/swift/AST/ProtocolConformance.h index 1136fb5be73cd..cf623733d4eb3 100644 --- a/include/swift/AST/ProtocolConformance.h +++ b/include/swift/AST/ProtocolConformance.h @@ -96,11 +96,34 @@ class alignas(1 << DeclAlignInBits) ProtocolConformance { /// conformance definition. Type ConformingType; + // The conformance has been used during Sema. + mutable bool hasBeenReferenced = false; + + // Ad-hoc conformances from protocol extensions must be private. + mutable bool fromProtocolExtension = false; + protected: ProtocolConformance(ProtocolConformanceKind kind, Type conformingType) : Kind(kind), ConformingType(conformingType) {} public: + ProtocolConformance *recordReferenced() const { + hasBeenReferenced = true; + return const_cast(this); + } + + bool isInUse() { + return hasBeenReferenced; + } + + void makePrivate() const { + fromProtocolExtension = true; + } + + bool isFromProtocolExtension() const { + return fromProtocolExtension; + } + /// Determine the kind of protocol conformance. ProtocolConformanceKind getKind() const { return Kind; } diff --git a/include/swift/AST/ProtocolConformanceRef.h b/include/swift/AST/ProtocolConformanceRef.h index e155092f58db6..ddf8c4d98a460 100644 --- a/include/swift/AST/ProtocolConformanceRef.h +++ b/include/swift/AST/ProtocolConformanceRef.h @@ -85,9 +85,7 @@ class ProtocolConformanceRef { bool isConcrete() const { return !isInvalid() && Union.is(); } - ProtocolConformance *getConcrete() const { - return Union.get(); - } + ProtocolConformance *getConcrete() const; bool isAbstract() const { return !isInvalid() && Union.is(); diff --git a/include/swift/AST/Requirement.h b/include/swift/AST/Requirement.h index c8b90bbc8b631..2008763e86900 100644 --- a/include/swift/AST/Requirement.h +++ b/include/swift/AST/Requirement.h @@ -60,10 +60,13 @@ class Requirement { LayoutConstraint SecondLayout; }; + /// Module this requirement was derived from. Used to re-order witness table. + unsigned ModuleNumber; + public: /// Create a conformance or same-type requirement. - Requirement(RequirementKind kind, Type first, Type second) - : FirstTypeAndKind(first, kind), SecondType(second) { + Requirement(RequirementKind kind, Type first, Type second, unsigned moduleNumber = 0) + : FirstTypeAndKind(first, kind), SecondType(second), ModuleNumber(moduleNumber) { assert(first); assert(second); } @@ -89,6 +92,10 @@ class Requirement { return SecondType; } + unsigned getModuleNumber() const { + return ModuleNumber; + } + /// Subst the types involved in this requirement. /// /// The \c args arguments are passed through to Type::subst. This doesn't diff --git a/include/swift/Basic/LangOptions.h b/include/swift/Basic/LangOptions.h index 9e4a053e972d4..010216027eb12 100644 --- a/include/swift/Basic/LangOptions.h +++ b/include/swift/Basic/LangOptions.h @@ -111,6 +111,9 @@ namespace swift { /// when using RequireExplicitAvailability. std::string RequireExplicitAvailabilityTarget; + /// Gate for conforming protocol extensions code. + bool EnableConformingExtensions = false; + /// If false, '#file' evaluates to the full path rather than a /// human-readable string. bool EnableConcisePoundFile = false; diff --git a/include/swift/Option/Options.td b/include/swift/Option/Options.td index b771b48b3895d..9a310b6170385 100644 --- a/include/swift/Option/Options.td +++ b/include/swift/Option/Options.td @@ -513,6 +513,11 @@ def enable_experimental_additive_arithmetic_derivation : Flags<[FrontendOption]>, HelpText<"Enable experimental 'AdditiveArithmetic' derived conformances">; +def enable_conforming_protocol_extensions : Flag<["-"], + "enable-conforming-protocol-extensions">, + Flags<[FrontendOption]>, + HelpText<"Enable experimental feature to conforming protocol extensions">; + def enable_experimental_concise_pound_file : Flag<["-"], "enable-experimental-concise-pound-file">, Flags<[FrontendOption]>, diff --git a/include/swift/SIL/SILWitnessVisitor.h b/include/swift/SIL/SILWitnessVisitor.h index 4c8bfeb4af50d..e5dee64f24237 100644 --- a/include/swift/SIL/SILWitnessVisitor.h +++ b/include/swift/SIL/SILWitnessVisitor.h @@ -50,8 +50,19 @@ template class SILWitnessVisitor : public ASTVisitor { public: void visitProtocolDecl(ProtocolDecl *protocol) { + // This chicanery is to move conformances arising from + // protocol extensions to the end of the witness table. + unsigned moduleNumber = 0; + while (visitProtocolDecl(protocol, moduleNumber++)) {} + } + + bool visitProtocolDecl(ProtocolDecl *protocol, unsigned moduleNumber) { + bool emittedConformance = false; + llvm::DenseMap seen; + // The protocol conformance descriptor gets added first. - asDerived().addProtocolConformanceDescriptor(); + if (moduleNumber == 0) + asDerived().addProtocolConformanceDescriptor(); for (const auto &reqt : protocol->getRequirementSignature()) { switch (reqt.getKind()) { @@ -72,6 +83,9 @@ template class SILWitnessVisitor : public ASTVisitor { if (!Lowering::TypeConverter::protocolRequiresWitnessTable(requirement)) continue; + if (reqt.getModuleNumber() != moduleNumber) + continue; + // If the type parameter is 'self', consider this to be protocol // inheritance. In the canonical signature, these should all // come before any protocol requirements on associated types. @@ -80,18 +94,23 @@ template class SILWitnessVisitor : public ASTVisitor { assert(parameter->getDepth() == 0 && parameter->getIndex() == 0 && "non-self type parameter in protocol"); asDerived().addOutOfLineBaseProtocol(requirement); + emittedConformance = true; continue; } // Otherwise, add an associated requirement. AssociatedConformance assocConf(protocol, type, requirement); asDerived().addAssociatedConformance(assocConf); + emittedConformance = true; continue; } } llvm_unreachable("bad requirement kind"); } + if (moduleNumber) + return emittedConformance; + // Add the associated types. for (auto *associatedType : protocol->getAssociatedTypeMembers()) { // If this is a new associated type (which does not override an @@ -100,13 +119,15 @@ template class SILWitnessVisitor : public ASTVisitor { asDerived().addAssociatedType(AssociatedType(associatedType)); } - if (asDerived().shouldVisitRequirementSignatureOnly()) - return; +// if (asDerived().shouldVisitRequirementSignatureOnly()) +// return true; // Visit the witnesses for the direct members of a protocol. for (Decl *member : protocol->getMembers()) { ASTVisitor::visit(member); } + + return true; } /// If true, only the base protocols and associated types will be visited. diff --git a/lib/AST/ConformanceLookupTable.cpp b/lib/AST/ConformanceLookupTable.cpp index 414d3d8803984..1baa39c3b84db 100644 --- a/lib/AST/ConformanceLookupTable.cpp +++ b/lib/AST/ConformanceLookupTable.cpp @@ -25,6 +25,7 @@ #include "swift/AST/ProtocolConformance.h" #include "swift/AST/ProtocolConformanceRef.h" #include "llvm/Support/SaveAndRestore.h" +#include "../Sema/TypeCheckProtocol.h" using namespace swift; @@ -266,13 +267,14 @@ void ConformanceLookupTable::inheritConformances(ClassDecl *classDecl, void ConformanceLookupTable::updateLookupTable(NominalTypeDecl *nominal, ConformanceStage stage) { + ++Updating; switch (stage) { case ConformanceStage::RecordedExplicit: // Record all of the explicit conformances. forEachInStage( stage, nominal, [&](NominalTypeDecl *nominal) { - addInheritedProtocols(nominal, + addInheritedProtocols(nominal, nominal, ConformanceSource::forExplicit(nominal)); }, [&](ExtensionDecl *ext, @@ -392,6 +394,7 @@ void ConformanceLookupTable::updateLookupTable(NominalTypeDecl *nominal, } break; } + --Updating; } void ConformanceLookupTable::loadAllConformances( @@ -422,16 +425,23 @@ bool ConformanceLookupTable::addProtocol(ProtocolDecl *protocol, SourceLoc loc, // recording). auto &conformanceEntries = Conformances[protocol]; if (kind == ConformanceEntryKind::Implied || + kind == ConformanceEntryKind::Explicit || kind == ConformanceEntryKind::Synthesized) { for (const auto *existingEntry : conformanceEntries) { switch (existingEntry->getKind()) { case ConformanceEntryKind::Explicit: case ConformanceEntryKind::Inherited: - return false; + // some subtle adjustments required here a table can be invalidated + // ... when inherited protocols change and reclaculated. + if (kind != ConformanceEntryKind::Explicit) + return false; + if (loc == existingEntry->getLoc()) + return false; + break; case ConformanceEntryKind::Implied: // Ignore implied circular protocol inheritance - if (existingEntry->getDeclContext() == dc) + if (existingEntry->getDeclContext() == dc && existingEntry->getLoc() == loc) return false; // An implied conformance is better than a synthesized one, unless @@ -463,16 +473,48 @@ bool ConformanceLookupTable::addProtocol(ProtocolDecl *protocol, SourceLoc loc, return true; } -void ConformanceLookupTable::addInheritedProtocols( +void ConformanceLookupTable::addInheritedProtocols(NominalTypeDecl *nominal, llvm::PointerUnion decl, ConformanceSource source) { // Find all of the protocols in the inheritance list. bool anyObject = false; for (const auto &found : - getDirectlyInheritedNominalTypeDecls(decl, anyObject)) { - if (auto proto = dyn_cast(found.Item)) - addProtocol(proto, found.Loc, source); - } + getDirectlyInheritedNominalTypeDecls(decl, anyObject)) + if (auto proto = dyn_cast(found.Item)) { + InheritedFrom[proto] = {decl, found.Loc}; + auto extended = isExtendedConformance(proto); + ExtensionDecl *protocolExtension = std::get<0>(extended); + SourceLoc loc = protocolExtension ? std::get<1>(extended) : found.Loc; + if (protocolExtension) + protocolExtension->getExtendedProtocolDecl() + ->recordExtendedNominal(nominal, protocolExtension); + proto->recordExtendedNominal(nominal, protocolExtension); + addProtocol(proto, loc, source); + } + + // Protocol extensions with conformances. + if (ProtocolDecl *extendedProtocol = + dyn_cast_or_null(decl.dyn_cast())) + for (ExtensionDecl *ext : extendedProtocol->getExtensions()) + addInheritedProtocols(nominal, ext, source); +} + +std::pair +ConformanceLookupTable::isExtendedConformance(ProtocolDecl *proto) { + auto decl_loc = InheritedFrom[proto]; + if (std::get<0>(decl_loc).isNull()) return {nullptr, SourceLoc()}; + + ExtensionDecl *extension = std::get<0>(decl_loc).dyn_cast(); + ProtocolDecl *extendedProtocol = extension ? + extension->getExtendedProtocolDecl() : + dyn_cast(std::get<0>(decl_loc).get()); + + if (extension && extendedProtocol) + return {extension, std::get<1>(decl_loc)}; + + auto ext = isExtendedConformance(extendedProtocol); + InheritedFrom[proto] = ext; + return ext; } void ConformanceLookupTable::expandImpliedConformances(NominalTypeDecl *nominal, @@ -504,7 +546,7 @@ void ConformanceLookupTable::expandImpliedConformances(NominalTypeDecl *nominal, } } - addInheritedProtocols(conformingProtocol, + addInheritedProtocols(nominal, conformingProtocol, ConformanceSource::forImplied(conformanceEntry)); } } @@ -661,7 +703,8 @@ ConformanceLookupTable::Ordering ConformanceLookupTable::compareConformances( return Ordering::Before; } auto module = lhs->getDeclContext()->getParentModule(); - assert(lhs->getDeclContext()->getParentModule() + assert(module->getASTContext().LangOpts.EnableConformingExtensions || + lhs->getDeclContext()->getParentModule() == rhs->getDeclContext()->getParentModule() && "conformances should be in the same module"); for (auto file : module->getFiles()) { @@ -1183,3 +1226,76 @@ void ConformanceLookupTable::dump(raw_ostream &os) const { } } +// Miscellaneous code added to implement conforming protocol extensions + +void ConformanceLookupTable::invalidate(NominalTypeDecl *nomimal, ProtocolDecl *proto) { + if (Updating) return; + LastProcessed.clear(); + InheritedFrom.clear(); + AllSupersededDiagnostics.clear(); +} + +void ExtensionDecl::setInherited(MutableArrayRef i) { + Inherited = i; + if (!Inherited.empty() && hasBeenBound()) + if (auto *proto = getExtendedProtocolDecl()) { + proto->inheritedProtocolsChanged(); + } +} + +void ProtocolDecl::inheritedProtocolsChanged() { + RequirementSignature = nullptr; + Bits.ProtocolDecl.InheritedProtocolsValid = false; + auto &extendeds = getASTContext().getExtendedConformances(); + auto protocol_nominals = extendeds.find(this); + if (protocol_nominals != extendeds.end()) + for (auto &nominal_extension : protocol_nominals->getSecond()) + if (NominalTypeDecl *nominal = nominal_extension.getFirst()) + nominal->prepareConformanceTable()->invalidate(nominal, this); +} + +void ProtocolDecl::recordExtendedNominal(NominalTypeDecl *nominal, ExtensionDecl *ext) { + assert(nominal); + getASTContext().getExtendedConformances()[this][nominal] = ext; +} + +void ASTContext::forEachExtendedConformance(ModuleDecl *module, + std::function emitWitness) { + SmallVector normals; + MultiConformanceChecker groupChecker(*this); + + for (auto &protocol_nominals : getExtendedConformances()) + for (auto &nominal_extensions : protocol_nominals.getSecond()) { + ProtocolDecl *proto = protocol_nominals.getFirst(); + NominalTypeDecl *nominal = nominal_extensions.getFirst(); + ExtensionDecl *extension = nominal_extensions.getSecond(); + + if (!extension) + continue; + if (isa(nominal)) + continue; + + SmallVector conformances; + nominal->prepareConformanceTable() + ->lookupConformance(module, nominal, proto, conformances); + + for (ProtocolConformance *conformance : conformances) + if (auto *normal = dyn_cast(conformance)) { + normals.push_back(normal); + if (!normal->isComplete()) + groupChecker.addConformance(normal); + normal->makePrivate(); + } + } + + if (groupChecker.checkAllConformances()) + exit(EXIT_FAILURE); + + for (NormalProtocolConformance *normal : normals) + if (normal->isComplete()) { + if (normal->isInUse() && !normal->isLazilyLoaded()) + emitWitness(normal); + } else + llvm::errs() << normal->getType()->getAnyNominal()->getName() << ": " << + normal->getProtocol()->getName().str() << " not complete\n"; +} diff --git a/lib/AST/ConformanceLookupTable.h b/lib/AST/ConformanceLookupTable.h index a27c89f03e063..b3e96d9ad1329 100644 --- a/lib/AST/ConformanceLookupTable.h +++ b/lib/AST/ConformanceLookupTable.h @@ -78,6 +78,9 @@ class ConformanceLookupTable { std::unordered_map> LastProcessed; + + /// Prevents invalidating while table is in the midst of being updated. + unsigned Updating = 0; struct ConformanceEntry; @@ -319,6 +322,11 @@ class ConformanceLookupTable { llvm::DenseMap> ConformingDeclMap; + typedef std::pair, SourceLoc> WhereFrom; + + /// Record of protocol and location a protocol is inferred from + llvm::DenseMap InheritedFrom; + /// Indicates whether we are visiting the superclass. bool VisitingSuperclass = false; @@ -327,9 +335,12 @@ class ConformanceLookupTable { ConformanceSource source); /// Add the protocols from the given list. - void addInheritedProtocols( - llvm::PointerUnion decl, - ConformanceSource source); + void addInheritedProtocols(NominalTypeDecl *nominal, + llvm::PointerUnion decl, + ConformanceSource source); + + /// Find any protocol extension in the chain of inhertance + std::pair isExtendedConformance(ProtocolDecl *proto); /// Expand the implied conformances for the given DeclContext. void expandImpliedConformances(NominalTypeDecl *nominal, DeclContext *dc); @@ -413,6 +424,9 @@ class ConformanceLookupTable { /// Create a new conformance lookup table. ConformanceLookupTable(ASTContext &ctx); + /// Reset that conformance table has been processed so it will be recalculated + void invalidate(NominalTypeDecl *nomimal, ProtocolDecl *decl); + /// Destroy the conformance table. void destroy(); diff --git a/lib/AST/Decl.cpp b/lib/AST/Decl.cpp index 19fb3aca7c390..1415c705d8449 100644 --- a/lib/AST/Decl.cpp +++ b/lib/AST/Decl.cpp @@ -1195,12 +1195,12 @@ ExtensionDecl::ExtensionDecl(SourceLoc extensionLoc, Decl(DeclKind::Extension, parent), IterableDeclContext(IterableDeclContextKind::ExtensionDecl), ExtensionLoc(extensionLoc), - ExtendedTypeRepr(extendedType), - Inherited(inherited) + ExtendedTypeRepr(extendedType) { Bits.ExtensionDecl.DefaultAndMaxAccessLevel = 0; Bits.ExtensionDecl.HasLazyConformances = false; setTrailingWhereClause(trailingWhereClause); + setInherited(inherited); } ExtensionDecl *ExtensionDecl::create(ASTContext &ctx, SourceLoc extensionLoc, diff --git a/lib/AST/GenericSignatureBuilder.cpp b/lib/AST/GenericSignatureBuilder.cpp index 08414cae27b50..15a11753a51dd 100644 --- a/lib/AST/GenericSignatureBuilder.cpp +++ b/lib/AST/GenericSignatureBuilder.cpp @@ -33,6 +33,7 @@ #include "swift/AST/TypeMatcher.h" #include "swift/AST/TypeRepr.h" #include "swift/AST/TypeWalker.h" +#include "swift/AST/NameLookup.h" #include "swift/Basic/Debug.h" #include "swift/Basic/Defer.h" #include "swift/Basic/Statistic.h" @@ -3855,6 +3856,23 @@ static ConstraintResult visitInherited( visitInherited(inheritedType, inherited.getTypeRepr()); } + // Protocol extensions with conformances. + if (auto *protoDecl = dyn_cast_or_null(typeDecl)) + for (auto *ext : protoDecl->getExtensions()) { + ArrayRef inheritedTypes = ext->getInherited(); + for (unsigned index : indices(inheritedTypes)) { + Type inheritedType + = evaluateOrDefault(evaluator, + InheritedTypeRequest{ext, index, + TypeResolutionStage::Structural}, + Type()); + if (!inheritedType) continue; + + const auto &inherited = inheritedTypes[index]; + visitInherited(inheritedType, inherited.getTypeRepr()); + } + } + return result; } @@ -4489,6 +4507,28 @@ ConstraintResult GenericSignatureBuilder::addTypeRequirement( anyErrors = true; } + // Protocol extensions with conformances. + if (auto *protoDecl = + dyn_cast_or_null(constraintType->getAnyNominal())) { + auto &conforms = resolvedSubject.getEquivalenceClass(*this)->conformsTo; + bool anyObject = false; + + for (auto *ext : protoDecl->getExtensions()) + for (const auto &found : + getDirectlyInheritedNominalTypeDecls(ext, anyObject)) + if (auto inheritedProtocol = dyn_cast(found.Item)) { + if (conforms.find(inheritedProtocol) != conforms.end()) + continue; + if (isErrorResult(addConformanceRequirement(resolvedSubject, + inheritedProtocol, + source))) + anyErrors = true; + else if (Type subjectType = resolvedSubject.getAsConcreteType()) + if (NominalTypeDecl *nominal = subjectType->getAnyNominal()) + inheritedProtocol->recordExtendedNominal(nominal, ext); + } + } + return anyErrors ? ConstraintResult::Conflicting : ConstraintResult::Resolved; } diff --git a/lib/AST/NameLookup.cpp b/lib/AST/NameLookup.cpp index 5da65daf1d014..ff6ce03e6d0cc 100644 --- a/lib/AST/NameLookup.cpp +++ b/lib/AST/NameLookup.cpp @@ -2251,6 +2251,15 @@ InheritedProtocolsRequest::evaluate(Evaluator &evaluator, } } + // Protocol extensions with conformances. + for (ExtensionDecl *ext : PD->getExtensions()) + for (const auto &found : getDirectlyInheritedNominalTypeDecls(ext, anyObject)) { + if (auto proto = dyn_cast(found.Item)) { + if (known.insert(proto).second) + result.push_back(proto); + } + } + return PD->getASTContext().AllocateCopy(result); } @@ -2431,6 +2440,10 @@ swift::getDirectlyInheritedNominalTypeDecls( if (!req.getFirstType()->isEqual(protoSelfTy)) continue; + // don't duplicate extended conformances. + if (req.getModuleNumber()) + continue; + result.emplace_back(req.getSecondType()->castTo()->getDecl(), loc); } diff --git a/lib/AST/ProtocolConformance.cpp b/lib/AST/ProtocolConformance.cpp index d612dff6a0574..49eb7e2455bc6 100644 --- a/lib/AST/ProtocolConformance.cpp +++ b/lib/AST/ProtocolConformance.cpp @@ -75,6 +75,10 @@ ProtocolConformanceRef::ProtocolConformanceRef(ProtocolDecl *protocol, } } +ProtocolConformance *ProtocolConformanceRef::getConcrete() const { + return Union.get()->recordReferenced(); +} + ProtocolDecl *ProtocolConformanceRef::getRequirement() const { assert(!isInvalid()); @@ -925,8 +929,8 @@ void NormalProtocolConformance::setWitness(ValueDecl *requirement, assert(!isa(requirement) && "Request type witness"); assert(getProtocol() == cast(requirement->getDeclContext()) && "requirement in wrong protocol"); - assert(Mapping.count(requirement) == 0 && "Witness already known"); - assert((!isComplete() || isInvalid() || + assert(getProtocol()->getASTContext().LangOpts.EnableConformingExtensions || + (!isComplete() || isInvalid() || requirement->getAttrs().hasAttribute() || requirement->getAttrs().isUnavailable( requirement->getASTContext())) && @@ -1203,13 +1207,20 @@ ProtocolConformance * ProtocolConformance::getInheritedConformance(ProtocolDecl *protocol) const { auto result = getAssociatedConformance(getProtocol()->getSelfInterfaceType(), protocol); + if (!result.isConcrete()) { + // Late conformance through protocol extension in different module + SmallVector conformances; + NominalTypeDecl *nominal = getType()->getAnyNominal(); + if (nominal->lookupConformance(nullptr, protocol, conformances)) + result = ProtocolConformanceRef(conformances.front()); + } return result.isConcrete() ? result.getConcrete() : nullptr; } #pragma mark Protocol conformance lookup -void NominalTypeDecl::prepareConformanceTable() const { +ConformanceLookupTable *NominalTypeDecl::prepareConformanceTable() const { if (ConformanceTable) - return; + return ConformanceTable; auto mutableThis = const_cast(this); ASTContext &ctx = getASTContext(); @@ -1222,7 +1233,7 @@ void NominalTypeDecl::prepareConformanceTable() const { if (file->getKind() != FileUnitKind::Source && file->getKind() != FileUnitKind::ClangModule && file->getKind() != FileUnitKind::DWARFModule) { - return; + return ConformanceTable; } SmallPtrSet protocols; @@ -1256,6 +1267,8 @@ void NominalTypeDecl::prepareConformanceTable() const { addSynthesized(KnownProtocolKind::RawRepresentable); } } + + return ConformanceTable; } bool NominalTypeDecl::lookupConformance( @@ -1383,8 +1396,28 @@ IterableDeclContext::takeConformanceDiagnostics() const { return result; } - // Protocols are not subject to the checks for supersession. + // When protocol has been extended, collect any diagnostics from all nominals. if (isa(nominal)) { + llvm::SmallVector nominalsWithExtendedConformances; + for (auto &protocol_nominals : getASTContext().getExtendedConformances()) + for (auto &nominal_extension : protocol_nominals.getSecond()) { + ExtensionDecl *ext = nominal_extension.getSecond(); + if (ext != this) + continue; + nominalsWithExtendedConformances.push_back(nominal_extension.getFirst()); + } + + for (NominalTypeDecl *nominal : nominalsWithExtendedConformances) { + nominal->prepareConformanceTable(); + nominal->ConformanceTable->lookupConformances( + nominal, + nominal, + ConformanceLookupKind::All, + nullptr, + nullptr, + &result); + } + return result; } diff --git a/lib/Driver/ToolChains.cpp b/lib/Driver/ToolChains.cpp index 232a69ac30487..af0dab0f32c74 100644 --- a/lib/Driver/ToolChains.cpp +++ b/lib/Driver/ToolChains.cpp @@ -249,6 +249,8 @@ void ToolChain::addCommonFrontendArgs(const OutputInfo &OI, inputArgs.AddLastArg(arguments, options::OPT_enable_astscope_lookup); inputArgs.AddLastArg(arguments, options::OPT_disable_astscope_lookup); inputArgs.AddLastArg(arguments, options::OPT_disable_parser_lookup); + inputArgs.AddLastArg(arguments, + options::OPT_enable_conforming_protocol_extensions); inputArgs.AddLastArg(arguments, options::OPT_enable_experimental_concise_pound_file); inputArgs.AddLastArg(arguments, diff --git a/lib/Frontend/CompilerInvocation.cpp b/lib/Frontend/CompilerInvocation.cpp index 46bdf5c5f0d8d..e9903f2c6af4b 100644 --- a/lib/Frontend/CompilerInvocation.cpp +++ b/lib/Frontend/CompilerInvocation.cpp @@ -542,6 +542,9 @@ static bool ParseLangArgs(LangOptions &Opts, ArgList &Args, Opts.OptimizationRemarkMissedPattern = generateOptimizationRemarkRegex(Diags, Args, A); + Opts.EnableConformingExtensions = + Args.hasArg(OPT_enable_conforming_protocol_extensions); + Opts.EnableConcisePoundFile = Args.hasArg(OPT_enable_experimental_concise_pound_file); diff --git a/lib/IRGen/GenProto.cpp b/lib/IRGen/GenProto.cpp index 2cb7ba801893d..479d13cb66569 100644 --- a/lib/IRGen/GenProto.cpp +++ b/lib/IRGen/GenProto.cpp @@ -36,6 +36,7 @@ #include "swift/AST/IRGenOptions.h" #include "swift/AST/PrettyStackTrace.h" #include "swift/AST/SubstitutionMap.h" +#include "swift/AST/DiagnosticsSema.h" #include "swift/ClangImporter/ClangModule.h" #include "swift/IRGen/Linking.h" #include "swift/SIL/SILDeclRef.h" @@ -1085,7 +1086,8 @@ mapConformanceIntoContext(IRGenModule &IGM, const RootProtocolConformance &conf, WitnessIndex ProtocolInfo::getAssociatedTypeIndex( IRGenModule &IGM, AssociatedType assocType) const { - assert(!IGM.isResilient(assocType.getSourceProtocol(), + assert(IGM.Context.LangOpts.EnableConformingExtensions || + !IGM.isResilient(assocType.getSourceProtocol(), ResilienceExpansion::Maximal) && "Cannot ask for the associated type index of non-resilient protocol"); for (auto &witness : getWitnessEntries()) { @@ -1954,6 +1956,8 @@ void IRGenModule::emitProtocolConformance( init.finishAndCreateFuture())); var->setConstant(true); setTrueConstGlobal(var); + if (record.conformance->isFromProtocolExtension()) + var->setLinkage(llvm::GlobalValue::InternalLinkage); } void IRGenerator::ensureRelativeSymbolCollocation(SILWitnessTable &wt) { @@ -1991,7 +1995,8 @@ void IRGenerator::ensureRelativeSymbolCollocation(SILDefaultWitnessTable &wt) { const ProtocolInfo &IRGenModule::getProtocolInfo(ProtocolDecl *protocol, ProtocolInfoKind kind) { // If the protocol is resilient, we cannot know the full witness table layout. - assert(!isResilient(protocol, ResilienceExpansion::Maximal) || + assert(Context.LangOpts.EnableConformingExtensions || + !isResilient(protocol, ResilienceExpansion::Maximal) || kind == ProtocolInfoKind::RequirementSignature); return Types.getProtocolInfo(protocol, kind); @@ -2160,6 +2165,9 @@ void IRGenModule::emitSILWitnessTable(SILWitnessTable *wt) { tableSize = wtableBuilder.getTableSize(); instantiationFunction = wtableBuilder.buildInstantiationFunction(); + + if (conf->isFromProtocolExtension()) + global->setLinkage(llvm::GlobalValue::InternalLinkage); } else { // Build the witness table. ResilientWitnessTableBuilder wtableBuilder(*this, wt); diff --git a/lib/SILGen/SILGen.cpp b/lib/SILGen/SILGen.cpp index 01a7a175dd4bb..830d7a164191a 100644 --- a/lib/SILGen/SILGen.cpp +++ b/lib/SILGen/SILGen.cpp @@ -1884,6 +1884,13 @@ class SILGenModuleRAII { continue; SGM.visit(TD); } + +#if 01 + SGM.getASTContext().forEachExtendedConformance(SGM.SwiftModule, + [&](NormalProtocolConformance *normal) { + SGM.getWitnessTable(normal); + }); +#endif } explicit SILGenModuleRAII(SILModule &M) : SGM{M, M.getSwiftModule()} {} diff --git a/lib/SILGen/SILGenType.cpp b/lib/SILGen/SILGenType.cpp index 75c2b1aa662fa..dd41c83178ab4 100644 --- a/lib/SILGen/SILGenType.cpp +++ b/lib/SILGen/SILGenType.cpp @@ -29,6 +29,7 @@ #include "swift/AST/SourceFile.h" #include "swift/AST/SubstitutionMap.h" #include "swift/AST/TypeMemberVisitor.h" +#include "swift/AST/DiagnosticsSema.h" #include "swift/SIL/FormalLinkage.h" #include "swift/SIL/PrettyStackTrace.h" #include "swift/SIL/SILArgument.h" @@ -665,6 +666,14 @@ SILFunction *SILGenModule::emitProtocolWitness( // Mapping from the requirement's generic signature to the witness // thunk's generic signature. auto reqtSubMap = witness.getRequirementToSyntheticSubs(); + if (reqtSubMap.empty()) { + if (auto conf = dyn_cast(conformance.getConcrete())) { + M.getASTContext().Diags. + diagnose(conf->getLoc(), diag::extension_protocol_limitation, + conf->getProtocol()->getName(), conf->getType()); + exit(EXIT_FAILURE); + } + } // The generic environment for the witness thunk. auto *genericEnv = witness.getSyntheticEnvironment(); diff --git a/lib/Sema/TypeCheckAttr.cpp b/lib/Sema/TypeCheckAttr.cpp index 934521b90670b..0368b6aab740a 100644 --- a/lib/Sema/TypeCheckAttr.cpp +++ b/lib/Sema/TypeCheckAttr.cpp @@ -705,8 +705,10 @@ bool AttributeChecker::visitAbstractAccessControlAttr( if (auto extension = dyn_cast(D)) { if (!extension->getInherited().empty()) { - diagnoseAndRemoveAttr(attr, diag::extension_access_with_conformances, - attr); + if (!extension->getExtendedProtocolDecl()) + diagnoseAndRemoveAttr(attr, diag::extension_access_with_conformances, attr); + else if (!(attr->getAccess() == AccessLevel::Public)) + diagnoseAndRemoveAttr(attr, diag::protocol_extension_access_with_conformances); return true; } } diff --git a/lib/Sema/TypeCheckDecl.cpp b/lib/Sema/TypeCheckDecl.cpp index 0b3c87267419f..77494ae0e44b9 100644 --- a/lib/Sema/TypeCheckDecl.cpp +++ b/lib/Sema/TypeCheckDecl.cpp @@ -772,6 +772,29 @@ RequirementSignatureRequest::evaluate(Evaluator &evaluator, SmallVector requirements; contextData->loader->loadRequirementSignature( proto, contextData->requirementSignatureData, requirements); + + if (!proto->isResilient()) { + // Tack on conformances from protocol extensions in modules + bool anyObject = false; + unsigned moduleNumber = 0; + ModuleDecl *currentModule = proto->getParentModule(); + for (auto *ext : proto->getExtensions()) { + if (currentModule != ext->getParentModule()) { + currentModule = ext->getParentModule(); + ++moduleNumber; + } + for (const auto &found : + getDirectlyInheritedNominalTypeDecls(ext, anyObject)) + if (!llvm::count_if(requirements, [&](Requirement req) { + return req.getSecondType()->getAnyNominal() == found.Item; + })) + requirements.push_back(Requirement(RequirementKind::Conformance, + proto->getSelfInterfaceType(), + found.Item->getDeclaredType(), + moduleNumber)); + } + } + if (requirements.empty()) return None; return ctx.AllocateCopy(requirements); diff --git a/lib/Sema/TypeCheckDeclPrimary.cpp b/lib/Sema/TypeCheckDeclPrimary.cpp index 912161abf0f84..79e41fd41cba9 100644 --- a/lib/Sema/TypeCheckDeclPrimary.cpp +++ b/lib/Sema/TypeCheckDeclPrimary.cpp @@ -85,11 +85,19 @@ static void checkInheritanceClause( // Protocol extensions cannot have inheritance clauses. if (auto proto = ext->getExtendedProtocolDecl()) { if (!inheritedClause.empty()) { - ext->diagnose(diag::extension_protocol_inheritance, - proto->getName()) - .highlight(SourceRange(inheritedClause.front().getSourceRange().Start, - inheritedClause.back().getSourceRange().End)); - return; + // Force recalculation conformance table + ext->setInherited(inheritedClause); + + auto *attr = ext->getAttrs().getAttribute(); + if (!attr || attr->getAccess() < AccessLevel::Public) + ext->diagnose(diag::protocol_extension_access_with_conformances) + .fixItInsert(ext->getStartLoc(), "public "); + + // error unless comforming extensions is enabled. + if (!ext->getASTContext().LangOpts.EnableConformingExtensions) + ext->diagnose(diag::extension_protocol_inheritance, proto->getName()) + .highlight(SourceRange(inheritedClause.front().getSourceRange().Start, + inheritedClause.back().getSourceRange().End)); } } } else { diff --git a/lib/Sema/TypeCheckProtocol.cpp b/lib/Sema/TypeCheckProtocol.cpp index b3b7a9dc8a89c..62585a7bdbe4b 100644 --- a/lib/Sema/TypeCheckProtocol.cpp +++ b/lib/Sema/TypeCheckProtocol.cpp @@ -1456,48 +1456,6 @@ RequirementCheck WitnessChecker::checkWitness(ValueDecl *requirement, # pragma mark Witness resolution -/// This is a wrapper of multiple instances of ConformanceChecker to allow us -/// to diagnose and fix code from a more global perspective; for instance, -/// having this wrapper can help issue a fixit that inserts protocol stubs from -/// multiple protocols under checking. -class swift::MultiConformanceChecker { - ASTContext &Context; - llvm::SmallVector UnsatisfiedReqs; - llvm::SmallVector AllUsedCheckers; - llvm::SmallVector AllConformances; - llvm::SetVector MissingWitnesses; - llvm::SmallPtrSet CoveredMembers; - - /// Check one conformance. - ProtocolConformance * checkIndividualConformance( - NormalProtocolConformance *conformance, bool issueFixit); - - /// Determine whether the given requirement was left unsatisfied. - bool isUnsatisfiedReq(NormalProtocolConformance *conformance, ValueDecl *req); -public: - MultiConformanceChecker(ASTContext &ctx) : Context(ctx) {} - - ASTContext &getASTContext() const { return Context; } - - /// Add a conformance into the batched checker. - void addConformance(NormalProtocolConformance *conformance) { - AllConformances.push_back(conformance); - } - - /// Peek the unsatisfied requirements collected during conformance checking. - ArrayRef getUnsatisfiedRequirements() { - return llvm::makeArrayRef(UnsatisfiedReqs); - } - - /// Whether this member is "covered" by one of the conformances. - bool isCoveredMember(ValueDecl *member) const { - return CoveredMembers.count(member) > 0; - } - - /// Check all conformances and emit diagnosis globally. - void checkAllConformances(); -}; - bool MultiConformanceChecker:: isUnsatisfiedReq(NormalProtocolConformance *conformance, ValueDecl *req) { if (conformance->isInvalid()) return false; @@ -1523,7 +1481,7 @@ isUnsatisfiedReq(NormalProtocolConformance *conformance, ValueDecl *req) { return false; } -void MultiConformanceChecker::checkAllConformances() { +bool MultiConformanceChecker::checkAllConformances() { bool anyInvalid = false; for (unsigned I = 0, N = AllConformances.size(); I < N; ++I) { auto *conformance = AllConformances[I]; @@ -1548,7 +1506,7 @@ void MultiConformanceChecker::checkAllConformances() { } // If all missing witnesses are issued with fixits, we are done. if (MissingWitnesses.empty()) - return; + return anyInvalid; // Otherwise, backtrack to the last checker that has missing witnesses // and diagnose missing witnesses from there. @@ -1558,6 +1516,8 @@ void MultiConformanceChecker::checkAllConformances() { It->diagnoseMissingWitnesses(MissingWitnessDiagnosisKind::FixItOnly); } } + + return true; } static void diagnoseConformanceImpliedByConditionalConformance( diff --git a/lib/Sema/TypeCheckProtocol.h b/lib/Sema/TypeCheckProtocol.h index 5149fbf6c5a8f..1d43dc89d96c2 100644 --- a/lib/Sema/TypeCheckProtocol.h +++ b/lib/Sema/TypeCheckProtocol.h @@ -962,4 +962,45 @@ void diagnoseConformanceFailure(Type T, } +/// This is a wrapper of multiple instances of ConformanceChecker to allow us +/// to diagnose and fix code from a more global perspective; for instance, +/// having this wrapper can help issue a fixit that inserts protocol stubs from +/// multiple protocols under checking. +class swift::MultiConformanceChecker { + ASTContext &Context; + llvm::SmallVector UnsatisfiedReqs; + llvm::SmallVector AllUsedCheckers; + llvm::SmallVector AllConformances; + llvm::SetVector MissingWitnesses; + llvm::SmallPtrSet CoveredMembers; + + /// Check one conformance. + ProtocolConformance * checkIndividualConformance( + NormalProtocolConformance *conformance, bool issueFixit); + + /// Determine whether the given requirement was left unsatisfied. + bool isUnsatisfiedReq(NormalProtocolConformance *conformance, ValueDecl *req); +public: + MultiConformanceChecker(ASTContext &ctx) : Context(ctx) {} + + ASTContext &getASTContext() const { return Context; } + + /// Add a conformance into the batched checker. + void addConformance(NormalProtocolConformance *conformance) { + AllConformances.push_back(conformance); + } + + /// Peek the unsatisfied requirements collected during conformance checking. + ArrayRef getUnsatisfiedRequirements() { + return llvm::makeArrayRef(UnsatisfiedReqs); + } + + /// Whether this member is "covered" by one of the conformances. + bool isCoveredMember(ValueDecl *member) const { + return CoveredMembers.count(member) > 0; + } + + /// Check all conformances and emit diagnosis globally. + bool checkAllConformances(); +}; #endif // SWIFT_SEMA_PROTOCOL_H diff --git a/lib/Sema/TypeChecker.cpp b/lib/Sema/TypeChecker.cpp index 230adc8ca4e47..421ab0c3cb624 100644 --- a/lib/Sema/TypeChecker.cpp +++ b/lib/Sema/TypeChecker.cpp @@ -349,6 +349,14 @@ TypeCheckSourceFileRequest::evaluate(Evaluator &eval, SourceFile *SF) const { } typeCheckDelayedFunctions(*SF); + + // re-typecheck protocol extensions to diagnose any redundant conformances + if (SF->getASTContext().LangOpts.EnableConformingExtensions) { + for (auto D : SF->getTopLevelDecls()) + if (auto *ED = dyn_cast(D)) + if (ED->getExtendedProtocolDecl() && ED->getInherited().size()) + TypeChecker::typeCheckDecl(ED); + } } // Check to see if there's any inconsistent @_implementationOnly imports. diff --git a/lib/Serialization/Deserialization.cpp b/lib/Serialization/Deserialization.cpp index 028919584dfd6..58a344b4a61ac 100644 --- a/lib/Serialization/Deserialization.cpp +++ b/lib/Serialization/Deserialization.cpp @@ -6054,6 +6054,7 @@ void ModuleFile::finishNormalConformance(NormalProtocolConformance *conformance, auto isConformanceReq = [](const Requirement &req) { return req.getKind() == RequirementKind::Conformance; }; +#if 0 if (conformanceCount != llvm::count_if(proto->getRequirementSignature(), isConformanceReq)) { fatal(llvm::make_error( @@ -6062,6 +6063,28 @@ void ModuleFile::finishNormalConformance(NormalProtocolConformance *conformance, } while (conformanceCount--) reqConformances.push_back(readConformance(DeclTypeCursor)); +#else + while (conformanceCount--) + reqConformances.push_back(readConformance(DeclTypeCursor)); + + // Tack on conformances due to protocol extensions. + SmallVector signatureConformances; + llvm::copy_if(proto->getRequirementSignature(), + std::back_inserter(signatureConformances), isConformanceReq); + for (unsigned extra = reqConformances.size(); + extra < signatureConformances.size(); extra++) + if (ProtocolDecl *second = dyn_cast(signatureConformances + [extra].getSecondType()->getAnyNominal())) + reqConformances.push_back(ProtocolConformanceRef(second)); + else + llvm::errs() << "Missed protocol extension conformance\n"; + + if (reqConformances.size() != signatureConformances.size()) { + fatal(llvm::make_error( + "serialized conformances do not match requirement signature", + llvm::inconvertibleErrorCode())); + } +#endif } conformance->setSignatureConformances(reqConformances); @@ -6100,6 +6123,8 @@ void ModuleFile::finishNormalConformance(NormalProtocolConformance *conformance, // In this situation we need to do a post-pass to fill in missing // requirements with opaque witnesses. bool needToFillInOpaqueValueWitnesses = false; + auto deserizeValueWitnesses = [&](unsigned valueCount, + bool deserializeSyntheticSubs) { while (valueCount--) { ValueDecl *req; @@ -6165,6 +6190,24 @@ void ModuleFile::finishNormalConformance(NormalProtocolConformance *conformance, fatal(witnessSubstitutions.takeError()); } + SubstitutionMap reqToSyntheticEnvSubs; + if (deserializeSyntheticSubs) { + auto reqSubstitutionsMaybe = getSubstitutionMapChecked(*rawIDIter++); + if (!reqSubstitutionsMaybe) { + // Missing module errors are most likely caused by an + // implementation-only import hiding types and decls. + // rdar://problem/52837313 + if (reqSubstitutionsMaybe.errorIsA()) { + consumeError(reqSubstitutionsMaybe.takeError()); + isOpaque = true; + } + else + fatal(reqSubstitutionsMaybe.takeError()); + } + else + reqToSyntheticEnvSubs = reqSubstitutionsMaybe.get(); + } + // Handle opaque witnesses that couldn't be deserialized. if (isOpaque) { trySetOpaqueWitness(); @@ -6172,10 +6215,18 @@ void ModuleFile::finishNormalConformance(NormalProtocolConformance *conformance, } // Set the witness. - trySetWitness(Witness::forDeserialized(witness, witnessSubstitutions.get())); + trySetWitness(Witness(witness, witnessSubstitutions.get(), + nullptr, reqToSyntheticEnvSubs)); } assert(rawIDIter <= rawIDs.end() && "read too much"); - + }; + + deserizeValueWitnesses(valueCount, /*deserializeSyntheticSubs*/false); + + // If RequirementToSyntheticSubs have been serialized, deserialize again + if (rawIDIter < rawIDs.end()) + deserizeValueWitnesses(valueCount, /*deserializeSyntheticSubs*/true); + // Fill in opaque value witnesses if we need to. if (needToFillInOpaqueValueWitnesses) { for (auto member : proto->getMembers()) { diff --git a/lib/Serialization/Serialization.cpp b/lib/Serialization/Serialization.cpp index 445399bc36c90..7d579b7e5b2de 100644 --- a/lib/Serialization/Serialization.cpp +++ b/lib/Serialization/Serialization.cpp @@ -1339,8 +1339,10 @@ void Serializer::writeASTBlockEntity( return false; }); - conformance->forEachValueWitness([&](ValueDecl *req, Witness witness) { - ++numValueWitnesses; + bool serializeSyntheticSubs = false; + auto seriaiseValueWitnesses = [&](ValueDecl *req, Witness witness) { + if (!serializeSyntheticSubs) + ++numValueWitnesses; data.push_back(addDeclRef(req)); data.push_back(addDeclRef(witness.getDecl())); assert(witness.getDecl() || req->getAttrs().hasAttribute() @@ -1364,7 +1366,32 @@ void Serializer::writeASTBlockEntity( subs = subs.mapReplacementTypesOutOfContext(); data.push_back(addSubstitutionMapRef(subs)); - }); + + if (!serializeSyntheticSubs) + return; + + auto syntheticSubs = witness.getRequirementToSyntheticSubs(); + + // Canonicalize8away typealiases, since these substitutions aren't used + // for diagnostics and we reference fewer declarations that way. + syntheticSubs = syntheticSubs.getCanonical(); + + // Map archetypes to type parameters, since we always substitute them + // away. Note that in a merge-modules pass, we're serializing conformances + // that we deserialized, so they will already have their replacement types + // in terms of interface types; hence the hasArchetypes() check is + // necessary for correctness, not just as a fast path. + if (syntheticSubs.hasArchetypes()) + syntheticSubs = syntheticSubs.mapReplacementTypesOutOfContext(); + + data.push_back(addSubstitutionMapRef(syntheticSubs)); + }; + + conformance->forEachValueWitness(seriaiseValueWitnesses); + + // serialize RequirementToSyntheticSubs to allow conforming protocl etensions + serializeSyntheticSubs = true; + conformance->forEachValueWitness(seriaiseValueWitnesses); unsigned numSignatureConformances = conformance->getSignatureConformances().size(); @@ -3241,6 +3268,14 @@ class Serializer::DeclSerializer : public DeclVisitor { dependencyTypes.insert(element.getType()); } + for (ExtensionDecl *ED : const_cast(proto)->getExtensions()) + for (auto element : ED->getInherited()) { + assert(!element.getType()->hasArchetype()); + inheritedAndDependencyTypes.push_back(S.addTypeRef(element.getType())); + if (element.getType()->is()) + dependencyTypes.insert(element.getType()); + } + for (Requirement req : proto->getRequirementSignature()) { // Requirements can be cyclic, so for now filter out any requirements // from elsewhere in the module. This isn't perfect---something else in diff --git a/test/decl/conforming_extensions.swift b/test/decl/conforming_extensions.swift new file mode 100644 index 0000000000000..a2bee21e5735f --- /dev/null +++ b/test/decl/conforming_extensions.swift @@ -0,0 +1,114 @@ +// RUN: %target-build-swift -enable-conforming-protocol-extensions %s -o %t.out +// RUN: %target-run %t.out | %FileCheck %s + +public extension FixedWidthInteger: ExpressibleByUnicodeScalarLiteral { + @_transparent + init(unicodeScalarLiteral value: Unicode.Scalar) { + self = Self(value.value) + } + func foo() -> String { + return "Foo!" + } +} + +let a: [Int16] = ["a", "b", "c", "d", "e"] +let b: [UInt32] = ["a", "b", "c", "d", "e"] + +protocol Q { + var bar: Double { get } + func foo2() -> String +} +public protocol P { +} +public extension P: Q { + var bar: Double { return -888 } + func foo2() -> String { + return "Foo2 \(self)!" + } +} +class C {} +struct S { + var a = 99 +} +public protocol P2 : P {} +public extension P2 {} + + extension C: P2 {} + extension S: P2 {} + +//public extension Numeric: P2 {} +public extension FixedWidthInteger: P2 {} + +// CHECK: Foo2 main.C! +print(C().foo2()) +// CHECK: Foo2 S(a: 99)! +print(S().foo2()) +// CHECK: Foo2 99! +print(Int8(99).foo2()) +// CHECK: Foo2 99! +//extension UInt32: P2 {} +print(UInt32(99).foo2()) +// CHECK: Foo! +print(Int8(99).foo()) +// CHECK: Foo! +print(Int32(99).foo()) +// CHECK: Foo! +print(UInt32(99).foo()) + +// CHECK: [97, 98, 99, 100, 101] +print(a) +// CHECK: [97, 98, 99, 100, 101] +print(b) +// CHECK: ["Foo2 97!", "Foo2 98!", "Foo2 99!", "Foo2 100!", "Foo2 101!"] +print(a.map {$0.foo2()}) +// CHECK: ["Foo!", "Foo!", "Foo!", "Foo!", "Foo!"] +print(b.map {$0.foo()}) +// CHECK: ["Foo2 97!", "Foo2 98!", "Foo2 99!", "Foo2 100!", "Foo2 101!"] +print(b.map {$0.foo2()}) + +public func use(_ value: T) -> T + where T : FixedWidthInteger { + print(value.foo2()) + return value + "1" // ← Used ExpressibleByUnicodeScalarLiteral +} + +// CHECK: 50 +print(use(1)) + +let u = UInt32(99) + +// CHECK: 148 +print(use(u)) + +func aaa(_ b: P2) { + print(b.foo2()) +} + +aaa(u) + +//aaa(88.0) +// +//print(99.0.foo2()) + +let v = b.map { $0 as P2 }.map { $0.foo2() } +// CHECK: ["Foo2 97!", "Foo2 98!", "Foo2 99!", "Foo2 100!", "Foo2 101!"] +print(v) + +// CHECK: -888.0 +print(99.bar) + +import Foundation + +public protocol CC { +} + +public struct AA: CC { + var a = 9 + var b = 8 +} + +public extension CC: Codable { +} + +// CHECK: {"a":9,"b":8} +print(String(data: try! JSONEncoder().encode(AA()), encoding: .utf8) ?? "?") diff --git a/test/decl/ext/protocol.swift b/test/decl/ext/protocol.swift index 860bcbcfbc5ba..5513778724cb8 100644 --- a/test/decl/ext/protocol.swift +++ b/test/decl/ext/protocol.swift @@ -947,7 +947,7 @@ enum Foo : Int, ReallyRaw { protocol BadProto1 { } protocol BadProto2 { } -extension BadProto1 : BadProto2 { } // expected-error{{extension of protocol 'BadProto1' cannot have an inheritance clause}} +extension BadProto1 : BadProto2 { } // expected-error{{protocol extensions with conformances must currently be public}} expected-error{{inheritance clause in extension of protocol 'BadProto1'. use -enable-conforming-protocol-extensions to enable this feature}} extension BadProto2 { struct S { } // expected-error{{type 'S' cannot be nested in protocol extension of 'BadProto2'}}