Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions include/swift/AST/ASTContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,15 @@ class ASTContext final {
llvm::BumpPtrAllocator &
getAllocator(AllocationArena arena = AllocationArena::Permanent) const;

/// Track protocol extensions that inherit to determine witness tables to emit.
mutable SmallVector<ExtensionDecl *, 2> ExtensionsWithConformances;

/// Protocol conformances that are a result of a protocol extension
mutable llvm::DenseMap<ProtocolDecl *,
llvm::DenseMap<ProtocolDecl *, ExtensionDecl *>> ExtendedConformances;

friend class ProtocolDecl;

public:
/// Allocate - Allocate memory from the ASTContext bump pointer.
void *Allocate(unsigned long bytes, unsigned alignment,
Expand Down Expand Up @@ -716,6 +725,19 @@ class ASTContext final {
/// one.
void loadExtensions(NominalTypeDecl *nominal, unsigned previousGeneration);

/// Keep track of protocol extensions that include a conformance.
///
/// \param extWithConformances A (protocol) extension that has conformances
void addExtensionWithConformances(ExtensionDecl *extWithConformances) {
ExtensionsWithConformances.emplace_back(extWithConformances);
}

/// Iterate over conformances arising from protocol extensions.
///
/// \param emitWitness A function to call to emit a witness table.
void forEachExtendedConformance(
std::function<void (NormalProtocolConformance *)> emitWitness);

/// Load the methods within the given class that produce
/// Objective-C class or instance methods with the given selector.
///
Expand Down
14 changes: 10 additions & 4 deletions include/swift/AST/Decl.h
Original file line number Diff line number Diff line change
Expand Up @@ -1735,7 +1735,7 @@ class ExtensionDecl final : public GenericContext, public Decl,
MutableArrayRef<TypeLoc> getInherited() { return Inherited; }
ArrayRef<TypeLoc> getInherited() const { return Inherited; }

void setInherited(MutableArrayRef<TypeLoc> i) { Inherited = i; }
void setInherited(MutableArrayRef<TypeLoc> i);

bool hasDefaultAccessLevel() const {
return Bits.ExtensionDecl.DefaultAndMaxAccessLevel != 0;
Expand Down Expand Up @@ -3338,9 +3338,6 @@ class NominalTypeDecl : public GenericTypeDecl, public IterableDeclContext {
/// a given nominal type.
mutable ConformanceLookupTable *ConformanceTable = nullptr;

/// Prepare the conformance table.
void prepareConformanceTable() const;

/// Returns the protocol requirements that \c Member conforms to.
ArrayRef<ValueDecl *>
getSatisfiedProtocolRequirementsForMember(const ValueDecl *Member,
Expand Down Expand Up @@ -3450,6 +3447,9 @@ class NominalTypeDecl : public GenericTypeDecl, public IterableDeclContext {
/// conform, such as AnyObject (for classes).
void getImplicitProtocols(SmallVectorImpl<ProtocolDecl *> &protocols);

/// Prepare the conformance table (also acts as accessor).
ConformanceLookupTable *prepareConformanceTable() const;

/// Look for conformances of this nominal type to the given
/// protocol.
///
Expand Down Expand Up @@ -4277,6 +4277,12 @@ class ProtocolDecl final : public NominalTypeDecl {
return const_cast<ProtocolDecl *>(this)->getInheritedProtocolsSlow();
}

/// An extension has inherited a new protocol
void inheritedProtocolsChanged();

//// Has a conformance come about from an extension?
bool isExtendedConformance(ProtocolDecl *proto);

/// Determine whether this protocol has a superclass.
bool hasSuperclass() const { return (bool)getSuperclassDecl(); }

Expand Down
8 changes: 6 additions & 2 deletions include/swift/AST/DiagnosticsSema.def
Original file line number Diff line number Diff line change
Expand Up @@ -1711,6 +1711,8 @@ NOTE(composition_in_extended_type_alternative,none,
ERROR(extension_access_with_conformances,none,
"%0 modifier cannot be used with extensions that declare "
"protocol conformances", (DeclAttribute))
ERROR(protocol_extension_access_with_conformances,none,
"protocol extensions with conformances must currently be public", ())
ERROR(extension_metatype,none,
"cannot extend a metatype %0", (Type))
ERROR(extension_specialization,none,
Expand All @@ -1723,8 +1725,10 @@ NOTE(extension_stored_property_fixit,none,
ERROR(extension_nongeneric_trailing_where,none,
"trailing 'where' clause for extension of non-generic type %0",
(DeclName))
ERROR(extension_protocol_inheritance,none,
"extension of protocol %0 cannot have an inheritance clause", (DeclName))
WARNING(extension_protocol_inheritance,none,
"inheritance clause in extension of protocol %0. "
"use -enable-conforming-protocol-extensions to remove this warning",
(DeclName))
ERROR(objc_generic_extension_using_type_parameter,none,
"extension of a generic Objective-C class cannot access the class's "
"generic parameters at runtime", ())
Expand Down
3 changes: 3 additions & 0 deletions include/swift/Basic/LangOptions.h
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,9 @@ namespace swift {
/// Detect and automatically import modules' cross-import overlays.
bool EnableCrossImportOverlays = false;

/// Gate for conforming protocol extensions code.
bool EnableConformingExtensions = false;

///
/// Support for alternate usage modes
///
Expand Down
5 changes: 5 additions & 0 deletions include/swift/Option/Options.td
Original file line number Diff line number Diff line change
Expand Up @@ -504,6 +504,11 @@ def enable_experimental_concise_pound_file : Flag<["-"],
Flags<[FrontendOption]>,
HelpText<"Enable experimental concise '#file' identifier">;

def enable_conforming_protocol_extensions : Flag<["-"],
"enable-conforming-protocol-extensions">,
Flags<[FrontendOption]>,
HelpText<"Enable experimental feature to allow protocol extensions to conform to protocols">;

// Diagnostic control options
def suppress_warnings : Flag<["-"], "suppress-warnings">,
Flags<[FrontendOption]>,
Expand Down
18 changes: 15 additions & 3 deletions include/swift/SIL/SILWitnessVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,15 @@ template <class T> class SILWitnessVisitor : public ASTVisitor<T> {

public:
void visitProtocolDecl(ProtocolDecl *protocol) {
// This chicanery is to move conformances arising from
// protocol extensions to the end of the witness table.
visitProtocolDecl(protocol, false);
visitProtocolDecl(protocol, true);
}

void visitProtocolDecl(ProtocolDecl *protocol, bool includeExtended) {
// The protocol conformance descriptor gets added first.
if (!includeExtended)
asDerived().addProtocolConformanceDescriptor();

for (const auto &reqt : protocol->getRequirementSignature()) {
Expand Down Expand Up @@ -79,12 +87,14 @@ template <class T> class SILWitnessVisitor : public ASTVisitor<T> {
assert(type->isEqual(protocol->getSelfInterfaceType()));
assert(parameter->getDepth() == 0 && parameter->getIndex() == 0 &&
"non-self type parameter in protocol");
asDerived().addOutOfLineBaseProtocol(requirement);
if (protocol->isExtendedConformance(requirement) == includeExtended)
asDerived().addOutOfLineBaseProtocol(requirement);
continue;
}

// Otherwise, add an associated requirement.
AssociatedConformance assocConf(protocol, type, requirement);
if (!includeExtended)
asDerived().addAssociatedConformance(assocConf);
continue;
}
Expand All @@ -93,17 +103,19 @@ template <class T> class SILWitnessVisitor : public ASTVisitor<T> {
}

// Add the associated types.
if (!includeExtended)
for (auto *associatedType : protocol->getAssociatedTypeMembers()) {
// If this is a new associated type (which does not override an
// existing associated type), add it.
if (associatedType->getOverriddenDecls().empty())
asDerived().addAssociatedType(AssociatedType(associatedType));
}

if (asDerived().shouldVisitRequirementSignatureOnly())
return;
// if (asDerived().shouldVisitRequirementSignatureOnly())
// return;

// Visit the witnesses for the direct members of a protocol.
if (!includeExtended)
for (Decl *member : protocol->getMembers()) {
ASTVisitor<T>::visit(member);
}
Expand Down
3 changes: 2 additions & 1 deletion lib/AST/ASTVerifier.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2620,7 +2620,8 @@ class Verifier : public ASTWalker {
}

auto proto = conformance->getProtocol();
if (normal->getDeclContext() != conformingDC) {
if (normal->getDeclContext() != conformingDC &&
!isa<ExtensionDecl>(conformingDC)) {
Out << "AST verification error: conformance of "
<< nominal->getName().str() << " to protocol "
<< proto->getName().str() << " is in the wrong context.\n"
Expand Down
137 changes: 135 additions & 2 deletions lib/AST/ConformanceLookupTable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include "swift/AST/ProtocolConformance.h"
#include "swift/AST/ProtocolConformanceRef.h"
#include "llvm/Support/SaveAndRestore.h"
#include "../Sema/TypeCheckProtocol.h"

using namespace swift;

Expand Down Expand Up @@ -475,11 +476,60 @@ void ConformanceLookupTable::addInheritedProtocols(
}
}

void ConformanceLookupTable::addImpliedProtocols(
NominalTypeDecl *nominal,
llvm::PointerUnion<TypeDecl *, ExtensionDecl *> decl,
ConformanceSource source,
llvm::DenseMap<NominalTypeDecl *, NominalTypeDecl *> &Seen,
Propagator *propagator) {
bool anyObject = false;
ExtensionDecl *extension = decl.dyn_cast<ExtensionDecl *>();
NominalTypeDecl *extending = extension ? extension->getExtendedNominal() :
dyn_cast<NominalTypeDecl>(decl.dyn_cast<TypeDecl *>());

// Closure to register inherited protocols against extending.
Propagator registerWitness = [&](ProtocolDecl *inheritedProto) {
if (extension && extending) {
auto *table = extending->prepareConformanceTable();
table->NominalConformancesFromExtension[extension][nominal][inheritedProto] = true;
}
// Continue propagating up stack.
if (propagator)
(*propagator)(inheritedProto);
};

// Find all of the protocols in the inheritance list.
unsigned count = 0, expected = extension ? extension->getInherited().size() : 0;
for (const auto &found :
getDirectlyInheritedNominalTypeDecls(decl, anyObject)) {
if (Seen.find(found.Item) == Seen.end()) {
Seen[found.Item] = extending; // prevent circularity (diagnose somewhere?)
addImpliedProtocols(nominal, found.Item, source, Seen, &registerWitness);
}
count++;
}
assert(!(count < expected));

if (!extension && extending) {
for (ExtensionDecl *ext : extending->getExtensions())
addImpliedProtocols(nominal, ext, source, Seen, &registerWitness);

if (auto proto = dyn_cast<ProtocolDecl>(extending)) {
addProtocol(proto, extension ? extension->getLoc() : extending->getLoc(), source);
if (propagator)
(*propagator)(/*inheritedProto*/proto);
}
}
}

void ConformanceLookupTable::expandImpliedConformances(NominalTypeDecl *nominal,
DeclContext *dc) {
// Note: recursive type-checking implies that AllConformances
// may be reallocated during this traversal, so pay the lookup cost
// during each iteration.
llvm::DenseMap<NominalTypeDecl *, NominalTypeDecl *> Seen;
unsigned topLevel = AllConformances[dc].size();

for (unsigned i = 0; i != AllConformances[dc].size(); ++i) {
/// FIXME: Avoid the possibility of an infinite loop by fixing the root
/// cause instead (incomplete circularity detection).
Expand All @@ -504,8 +554,9 @@ void ConformanceLookupTable::expandImpliedConformances(NominalTypeDecl *nominal,
}
}

addInheritedProtocols(conformingProtocol,
ConformanceSource::forImplied(conformanceEntry));
if (i < topLevel)
addImpliedProtocols(nominal, conformingProtocol,
ConformanceSource::forImplied(conformanceEntry), Seen);
}
}

Expand Down Expand Up @@ -1183,3 +1234,85 @@ void ConformanceLookupTable::dump(raw_ostream &os) const {
}
}

// Miscellaneous code added to implement conforming protocol extensions

void ConformanceLookupTable::addExtendedConformances(const ExtensionDecl *ext,
SmallVectorImpl<ProtocolConformance *> &conformances) {
MultiConformanceChecker groupChecker(ext->getASTContext());
for (auto &ext : NominalConformancesFromExtension)
for (auto &nominalPair : ext.second) {
NominalTypeDecl *nominal = nominalPair.first;
if (nominal->getSelfProtocolDecl())
continue;
for (auto &protocolPair : nominalPair.second) {
ProtocolDecl *proto = protocolPair.first;
auto table = nominal->prepareConformanceTable();
auto entry = table->Conformances.find(proto);
if (entry == table->Conformances.end())
continue;

auto conformance = table->getConformance(nominal, entry->second.back());
if (auto normal = dyn_cast<NormalProtocolConformance>(conformance)) {
conformances.push_back(conformance);
groupChecker.addConformance(normal);
}
}
}

if (groupChecker.checkAllConformances())
exit(EXIT_FAILURE);
}

void ConformanceLookupTable::invalidate(NominalTypeDecl *nomimal) {
for (auto &extInfo : NominalConformancesFromExtension)
for (auto &toInvalidate : extInfo.second) {
NominalTypeDecl *nominal = toInvalidate.first;
if (auto *proto = dyn_cast<ProtocolDecl>(nominal))
proto->inheritedProtocolsChanged();
else
nominal->prepareConformanceTable()->invalidate(nominal);
}

LastProcessed.clear();
}

void ProtocolDecl::inheritedProtocolsChanged() {
Bits.ProtocolDecl.InheritedProtocolsValid = false;
prepareConformanceTable()->invalidate(this);
RequirementSignature = nullptr;
}

bool ProtocolDecl::isExtendedConformance(ProtocolDecl *proto) {
if (!Bits.ProtocolDecl.InheritedProtocolsValid)
(void)getInheritedProtocolsSlow();
auto &extendedProtocols = getASTContext().ExtendedConformances;
auto extendedProto = extendedProtocols.find(this);
return extendedProto != extendedProtocols.end() &&
extendedProto->second.find(proto) != extendedProto->second.end();
}

void ExtensionDecl::setInherited(MutableArrayRef<TypeLoc> i) {
Inherited = i;
if (!Inherited.empty()) {
getASTContext().addExtensionWithConformances(this);

if (hasBeenBound())
if (auto *proto = getExtendedProtocolDecl())
proto->inheritedProtocolsChanged();
}
}

void ASTContext::forEachExtendedConformance(
std::function<void (NormalProtocolConformance *)> emitWitness) {
for (ExtensionDecl *ext : ExtensionsWithConformances)
if (ext->hasBeenBound())
if (ProtocolDecl *proto = ext->getExtendedProtocolDecl()) {
SmallVector<ProtocolConformance *, 2> result;
proto->prepareConformanceTable()->addExtendedConformances(ext, result);
for (auto conformance : result)
if (auto *normal = dyn_cast<NormalProtocolConformance>(conformance))
if (!conformance->getType()->getCanonicalType()
->getAnyNominal()->getSelfProtocolDecl())
emitWitness(normal);
}
}
Loading