From 28c1178c4f6b6d81953ca7b889e55992a81b51d8 Mon Sep 17 00:00:00 2001 From: Slava Pestov Date: Fri, 19 Nov 2021 14:31:09 -0500 Subject: [PATCH 1/9] RequirementMachine: Introduce TypeAliasRequirementsRequest This is a verbatim copy of the GenericSignatureBuilder's somewhat questionable (but necessary for source compatibility) logic where protocol typealiases with the same name as some other associated type imply a same-type requirement. The related diagnostics are there too, but only emitted when -requirement-machine-protocol-signatures=on; in 'verify' mode, the GSB will emit the same diagnostics. --- include/swift/AST/Decl.h | 6 + include/swift/AST/TypeCheckRequests.h | 19 ++ include/swift/AST/TypeCheckerTypeIDZone.def | 3 + lib/AST/Decl.cpp | 7 + .../RequirementLowering.cpp | 263 +++++++++++++++++- ...ocol_typealias_same_type_requirement.swift | 45 +++ 6 files changed, 339 insertions(+), 4 deletions(-) create mode 100644 test/Generics/protocol_typealias_same_type_requirement.swift diff --git a/include/swift/AST/Decl.h b/include/swift/AST/Decl.h index cbd6503dd9dee..1f402f959eed7 100644 --- a/include/swift/AST/Decl.h +++ b/include/swift/AST/Decl.h @@ -4225,6 +4225,7 @@ class ProtocolDecl final : public NominalTypeDecl { friend class SuperclassDeclRequest; friend class SuperclassTypeRequest; friend class StructuralRequirementsRequest; + friend class TypeAliasRequirementsRequest; friend class ProtocolDependenciesRequest; friend class RequirementSignatureRequest; friend class RequirementSignatureRequestRQM; @@ -4421,6 +4422,11 @@ class ProtocolDecl final : public NominalTypeDecl { /// instead. ArrayRef getStructuralRequirements() const; + /// Retrieve same-type requirements implied by protocol typealiases with the + /// same name as associated types, and diagnose cases that are better expressed + /// via a 'where' clause. + ArrayRef getTypeAliasRequirements() const; + /// Get the list of protocols appearing on the right hand side of conformance /// requirements. Computed from the structural requirements, above. ArrayRef getProtocolDependencies() const; diff --git a/include/swift/AST/TypeCheckRequests.h b/include/swift/AST/TypeCheckRequests.h index 3a45e89f2936d..532133226e3c2 100644 --- a/include/swift/AST/TypeCheckRequests.h +++ b/include/swift/AST/TypeCheckRequests.h @@ -387,6 +387,25 @@ class StructuralRequirementsRequest : bool isCached() const { return true; } }; +class TypeAliasRequirementsRequest : + public SimpleRequest(ProtocolDecl *), + RequestFlags::Cached> { +public: + using SimpleRequest::SimpleRequest; + +private: + friend SimpleRequest; + + // Evaluation. + ArrayRef + evaluate(Evaluator &evaluator, ProtocolDecl *proto) const; + +public: + // Caching. + bool isCached() const { return true; } +}; + class ProtocolDependenciesRequest : public SimpleRequest(ProtocolDecl *), diff --git a/include/swift/AST/TypeCheckerTypeIDZone.def b/include/swift/AST/TypeCheckerTypeIDZone.def index 666fad48708a4..b34e85806f03f 100644 --- a/include/swift/AST/TypeCheckerTypeIDZone.def +++ b/include/swift/AST/TypeCheckerTypeIDZone.def @@ -227,6 +227,9 @@ SWIFT_REQUEST(TypeChecker, RequirementRequest, SWIFT_REQUEST(TypeChecker, StructuralRequirementsRequest, ArrayRef(ProtocolDecl *), Cached, HasNearestLocation) +SWIFT_REQUEST(TypeChecker, TypeAliasRequirementsRequest, + ArrayRef(ProtocolDecl *), Cached, + HasNearestLocation) SWIFT_REQUEST(TypeChecker, ProtocolDependenciesRequest, ArrayRef(ProtocolDecl *), Cached, HasNearestLocation) diff --git a/lib/AST/Decl.cpp b/lib/AST/Decl.cpp index 36db5d2aafdf5..32266af3cab53 100644 --- a/lib/AST/Decl.cpp +++ b/lib/AST/Decl.cpp @@ -5260,6 +5260,13 @@ ProtocolDecl::getStructuralRequirements() const { None); } +ArrayRef +ProtocolDecl::getTypeAliasRequirements() const { + return evaluateOrDefault(getASTContext().evaluator, + TypeAliasRequirementsRequest { const_cast(this) }, + None); +} + ArrayRef ProtocolDecl::getProtocolDependencies() const { return evaluateOrDefault(getASTContext().evaluator, diff --git a/lib/AST/RequirementMachine/RequirementLowering.cpp b/lib/AST/RequirementMachine/RequirementLowering.cpp index c1cb623e7a9a0..a06c67840b9d8 100644 --- a/lib/AST/RequirementMachine/RequirementLowering.cpp +++ b/lib/AST/RequirementMachine/RequirementLowering.cpp @@ -25,6 +25,7 @@ #include "RequirementLowering.h" #include "swift/AST/ASTContext.h" #include "swift/AST/Decl.h" +#include "swift/AST/DiagnosticsSema.h" #include "swift/AST/ExistentialLayout.h" #include "swift/AST/Requirement.h" #include "swift/AST/TypeCheckRequests.h" @@ -343,6 +344,258 @@ StructuralRequirementsRequest::evaluate(Evaluator &evaluator, return ctx.AllocateCopy(result); } +ArrayRef +TypeAliasRequirementsRequest::evaluate(Evaluator &evaluator, + ProtocolDecl *proto) const { + // @objc protocols don't have associated types, so all of the below + // becomes a trivial no-op. + if (proto->isObjC()) + return ArrayRef(); + + assert(!proto->hasLazyRequirementSignature()); + + SmallVector result; + + auto &ctx = proto->getASTContext(); + + // In Verify mode, the GenericSignatureBuilder will emit the same diagnostics. + bool emitDiagnostics = + (ctx.LangOpts.RequirementMachineProtocolSignatures == + RequirementMachineMode::Enabled); + + // Collect all typealiases from inherited protocols recursively. + llvm::MapVector> inheritedTypeDecls; + for (auto *inheritedProto : ctx.getRewriteContext().getInheritedProtocols(proto)) { + for (auto req : inheritedProto->getMembers()) { + if (auto *typeReq = dyn_cast(req)) { + // Ignore generic typealiases. + if (auto typeAliasReq = dyn_cast(typeReq)) + if (typeAliasReq->isGeneric()) + continue; + + inheritedTypeDecls[typeReq->getName()].push_back(typeReq); + } + } + } + + auto getStructuralType = [](TypeDecl *typeDecl) -> Type { + if (auto typealias = dyn_cast(typeDecl)) { + if (typealias->getUnderlyingTypeRepr() != nullptr) { + auto type = typealias->getStructuralType(); + if (auto *aliasTy = cast(type.getPointer())) + return aliasTy->getSinglyDesugaredType(); + return type; + } + return typealias->getUnderlyingType(); + } + + return typeDecl->getDeclaredInterfaceType(); + }; + + // An inferred same-type requirement between the two type declarations + // within this protocol or a protocol it inherits. + auto recordInheritedTypeRequirement = [&](TypeDecl *first, TypeDecl *second) { + desugarSameTypeRequirement(getStructuralType(first), + getStructuralType(second), result); + }; + + // Local function to find the insertion point for the protocol's "where" + // clause, as well as the string to start the insertion ("where" or ","); + auto getProtocolWhereLoc = [&]() -> Located { + // Already has a trailing where clause. + if (auto trailing = proto->getTrailingWhereClause()) + return { ", ", trailing->getRequirements().back().getSourceRange().End }; + + // Inheritance clause. + return { " where ", proto->getInherited().back().getSourceRange().End }; + }; + + // Retrieve the set of requirements that a given associated type declaration + // produces, in the form that would be seen in the where clause. + const auto getAssociatedTypeReqs = [&](const AssociatedTypeDecl *assocType, + const char *start) { + std::string result; + { + llvm::raw_string_ostream out(result); + out << start; + interleave(assocType->getInherited(), [&](TypeLoc inheritedType) { + out << assocType->getName() << ": "; + if (auto inheritedTypeRepr = inheritedType.getTypeRepr()) + inheritedTypeRepr->print(out); + else + inheritedType.getType().print(out); + }, [&] { + out << ", "; + }); + + if (const auto whereClause = assocType->getTrailingWhereClause()) { + if (!assocType->getInherited().empty()) + out << ", "; + + whereClause->print(out, /*printWhereKeyword*/false); + } + } + return result; + }; + + // Retrieve the requirement that a given typealias introduces when it + // overrides an inherited associated type with the same name, as a string + // suitable for use in a where clause. + auto getConcreteTypeReq = [&](TypeDecl *type, const char *start) { + std::string result; + { + llvm::raw_string_ostream out(result); + out << start; + out << type->getName() << " == "; + if (auto typealias = dyn_cast(type)) { + if (auto underlyingTypeRepr = typealias->getUnderlyingTypeRepr()) + underlyingTypeRepr->print(out); + else + typealias->getUnderlyingType().print(out); + } else { + type->print(out); + } + } + return result; + }; + + for (auto assocTypeDecl : proto->getAssociatedTypeMembers()) { + // Check whether we inherited any types with the same name. + auto knownInherited = + inheritedTypeDecls.find(assocTypeDecl->getName()); + if (knownInherited == inheritedTypeDecls.end()) continue; + + bool shouldWarnAboutRedeclaration = + emitDiagnostics && + !assocTypeDecl->getAttrs().hasAttribute() && + !assocTypeDecl->getAttrs().hasAttribute() && + !assocTypeDecl->hasDefaultDefinitionType() && + (!assocTypeDecl->getInherited().empty() || + assocTypeDecl->getTrailingWhereClause() || + ctx.LangOpts.WarnImplicitOverrides); + for (auto inheritedType : knownInherited->second) { + // If we have inherited associated type... + if (auto inheritedAssocTypeDecl = + dyn_cast(inheritedType)) { + // Complain about the first redeclaration. + if (shouldWarnAboutRedeclaration) { + auto inheritedFromProto = inheritedAssocTypeDecl->getProtocol(); + auto fixItWhere = getProtocolWhereLoc(); + ctx.Diags.diagnose(assocTypeDecl, + diag::inherited_associated_type_redecl, + assocTypeDecl->getName(), + inheritedFromProto->getDeclaredInterfaceType()) + .fixItInsertAfter( + fixItWhere.Loc, + getAssociatedTypeReqs(assocTypeDecl, fixItWhere.Item)) + .fixItRemove(assocTypeDecl->getSourceRange()); + + ctx.Diags.diagnose(inheritedAssocTypeDecl, diag::decl_declared_here, + inheritedAssocTypeDecl->getName()); + + shouldWarnAboutRedeclaration = false; + } + + continue; + } + + if (emitDiagnostics) { + // We inherited a type; this associated type will be identical + // to that typealias. + auto inheritedOwningDecl = + inheritedType->getDeclContext()->getSelfNominalTypeDecl(); + ctx.Diags.diagnose(assocTypeDecl, + diag::associated_type_override_typealias, + assocTypeDecl->getName(), + inheritedOwningDecl->getDescriptiveKind(), + inheritedOwningDecl->getDeclaredInterfaceType()); + } + + recordInheritedTypeRequirement(assocTypeDecl, inheritedType); + } + + inheritedTypeDecls.erase(knownInherited); + } + + // Check all remaining inherited type declarations to determine if + // this protocol has a non-associated-type type with the same name. + inheritedTypeDecls.remove_if( + [&](const std::pair> &inherited) { + const auto name = inherited.first; + for (auto found : proto->lookupDirect(name)) { + // We only want concrete type declarations. + auto type = dyn_cast(found); + if (!type || isa(type)) continue; + + // Ignore nominal types. They're always invalid declarations. + if (isa(type)) + continue; + + // ... from the same module as the protocol. + if (type->getModuleContext() != proto->getModuleContext()) continue; + + // Ignore types defined in constrained extensions; their equivalence + // to the associated type would have to be conditional, which we cannot + // model. + if (auto ext = dyn_cast(type->getDeclContext())) { + if (ext->isConstrainedExtension()) continue; + } + + // We found something. + bool shouldWarnAboutRedeclaration = emitDiagnostics; + + for (auto inheritedType : inherited.second) { + // If we have inherited associated type... + if (auto inheritedAssocTypeDecl = + dyn_cast(inheritedType)) { + // Infer a same-type requirement between the typealias' underlying + // type and the inherited associated type. + recordInheritedTypeRequirement(inheritedAssocTypeDecl, type); + + // Warn that one should use where clauses for this. + if (shouldWarnAboutRedeclaration) { + auto inheritedFromProto = inheritedAssocTypeDecl->getProtocol(); + auto fixItWhere = getProtocolWhereLoc(); + ctx.Diags.diagnose(type, + diag::typealias_override_associated_type, + name, + inheritedFromProto->getDeclaredInterfaceType()) + .fixItInsertAfter(fixItWhere.Loc, + getConcreteTypeReq(type, fixItWhere.Item)) + .fixItRemove(type->getSourceRange()); + ctx.Diags.diagnose(inheritedAssocTypeDecl, diag::decl_declared_here, + inheritedAssocTypeDecl->getName()); + + shouldWarnAboutRedeclaration = false; + } + + continue; + } + + // Two typealiases that should be the same. + recordInheritedTypeRequirement(inheritedType, type); + } + + // We can remove this entry. + return true; + } + + return false; + }); + + // Infer same-type requirements among inherited type declarations. + for (auto &entry : inheritedTypeDecls) { + if (entry.second.size() < 2) continue; + + auto firstDecl = entry.second.front(); + for (auto otherDecl : ArrayRef(entry.second).slice(1)) { + recordInheritedTypeRequirement(firstDecl, otherDecl); + } + } + + return ctx.AllocateCopy(result); +} + ArrayRef ProtocolDependenciesRequest::evaluate(Evaluator &evaluator, ProtocolDecl *proto) const { @@ -635,11 +888,13 @@ void RuleBuilder::collectRulesFromReferencedProtocols() { // we can trigger the computation of the requirement signatures of the // next component recursively. if (ProtocolMap[proto]) { - for (auto req : proto->getStructuralRequirements()) { - // FIXME: Keep source location information around for redundancy - // diagnostics. + // FIXME: Keep source location information around for redundancy + // diagnostics. + for (auto req : proto->getStructuralRequirements()) addRequirement(req.req.getCanonical(), proto); - } + + for (auto req : proto->getTypeAliasRequirements()) + addRequirement(req.getCanonical(), proto); } else { for (auto req : proto->getRequirementSignature()) addRequirement(req.getCanonical(), proto); diff --git a/test/Generics/protocol_typealias_same_type_requirement.swift b/test/Generics/protocol_typealias_same_type_requirement.swift new file mode 100644 index 0000000000000..84c550ee4e104 --- /dev/null +++ b/test/Generics/protocol_typealias_same_type_requirement.swift @@ -0,0 +1,45 @@ +// RUN: %target-swift-frontend -typecheck %s -debug-generic-signatures -requirement-machine-protocol-signatures=on 2>&1 | %FileCheck %s + +protocol P1 { + associatedtype A +} + +protocol P2 { + associatedtype B +} + +// CHECK-LABEL: protocol_typealias_same_type_requirement.(file).P3@ +// CHECK-LABEL: Requirement signature: +protocol P3 : P1, P2 { + typealias A = B +} + +// CHECK-LABEL: protocol_typealias_same_type_requirement.(file).P4@ +// CHECK-LABEL: Requirement signature: +protocol P4 : P1, P2 { + typealias B = Int +} + +// CHECK-LABEL: protocol_typealias_same_type_requirement.(file).P5@ +// CHECK-LABEL: Requirement signature: +protocol P5 { + associatedtype A + associatedtype B +} + +extension P5 where A == Int { + typealias B = Int +} + +protocol P6 { + typealias A = Array +} + +protocol P7 { + associatedtype X + typealias A = Array +} + +// CHECK-LABEL: protocol_typealias_same_type_requirement.(file).P8@ +// CHECK-LABEL: Requirement signature: +protocol P8 : P6, P7 {} \ No newline at end of file From f0899e3acbb5e2fb63000b80d9f7afa91c1456e4 Mon Sep 17 00:00:00 2001 From: Slava Pestov Date: Fri, 19 Nov 2021 15:25:46 -0500 Subject: [PATCH 2/9] RequirementMachine: Make some entry points in RequirementLowering.cpp public --- .../RequirementLowering.cpp | 26 ++++++++++++++----- .../RequirementMachine/RequirementLowering.h | 17 ++++++++++-- 2 files changed, 35 insertions(+), 8 deletions(-) diff --git a/lib/AST/RequirementMachine/RequirementLowering.cpp b/lib/AST/RequirementMachine/RequirementLowering.cpp index a06c67840b9d8..9c75228fe53b9 100644 --- a/lib/AST/RequirementMachine/RequirementLowering.cpp +++ b/lib/AST/RequirementMachine/RequirementLowering.cpp @@ -210,13 +210,24 @@ static void realizeTypeRequirement(Type subjectType, Type constraintType, result.push_back({req, loc, wasInferred}); } -static void inferRequirements(Type type, SourceLoc loc, - SmallVectorImpl &result) { +/// Infer requirements from applications of BoundGenericTypes to type +/// parameters. For example, given a function declaration +/// +/// func union(_ x: Set, _ y: Set) +/// +/// We automatically infer 'T : Hashable' from the fact that 'struct Set' +/// declares a Hashable requirement on its generic parameter. +void swift::rewriting::inferRequirements( + Type type, SourceLoc loc, + SmallVectorImpl &result) { // FIXME: Implement } -static void realizeRequirement(Requirement req, RequirementRepr *reqRepr, bool infer, - SmallVectorImpl &result) { +/// Desugar a requirement and perform requirement inference if requested +/// to obtain zero or more structural requirements. +void swift::rewriting::realizeRequirement( + Requirement req, RequirementRepr *reqRepr, bool infer, + SmallVectorImpl &result) { auto firstType = req.getFirstType(); if (infer) { auto firstLoc = (reqRepr ? reqRepr->getFirstTypeRepr()->getStartLoc() @@ -269,8 +280,11 @@ static void realizeRequirement(Requirement req, RequirementRepr *reqRepr, bool i } } -static void realizeInheritedRequirements(TypeDecl *decl, Type type, bool infer, - SmallVectorImpl &result) { +/// Collect structural requirements written in the inheritance clause of an +/// AssociatedTypeDecl or GenericTypeParamDecl. +void swift::rewriting::realizeInheritedRequirements( + TypeDecl *decl, Type type, bool infer, + SmallVectorImpl &result) { auto &ctx = decl->getASTContext(); auto inheritedTypes = decl->getInherited(); diff --git a/lib/AST/RequirementMachine/RequirementLowering.h b/lib/AST/RequirementMachine/RequirementLowering.h index a5d55d1c59317..2f6e65051a923 100644 --- a/lib/AST/RequirementMachine/RequirementLowering.h +++ b/lib/AST/RequirementMachine/RequirementLowering.h @@ -34,8 +34,21 @@ class Requirement; namespace rewriting { -void -desugarRequirement(Requirement req, SmallVectorImpl &result); +// Entry points used by AbstractGenericSignatureRequest and +// InferredGenericSignatureRequest; see RequiremetnLowering.cpp for +// documentation +// comments. + +void desugarRequirement(Requirement req, SmallVectorImpl &result); + +void inferRequirements(Type type, SourceLoc loc, + SmallVectorImpl &result); + +void realizeRequirement(Requirement req, RequirementRepr *reqRepr, bool infer, + SmallVectorImpl &result); + +void realizeInheritedRequirements(TypeDecl *decl, Type type, bool infer, + SmallVectorImpl &result); /// A utility class for bulding rewrite rules from the top-level requirements /// of a generic signature. From 1c78b0466b567f836d538fa717bb6bccd75ba561 Mon Sep 17 00:00:00 2001 From: Slava Pestov Date: Fri, 19 Nov 2021 15:42:11 -0500 Subject: [PATCH 3/9] RequirementMachine: Clean up the RequirementMachine::initWith*() methods a bit --- .../RequirementMachine/RequirementMachine.cpp | 32 ++++++++----------- 1 file changed, 13 insertions(+), 19 deletions(-) diff --git a/lib/AST/RequirementMachine/RequirementMachine.cpp b/lib/AST/RequirementMachine/RequirementMachine.cpp index d9ce94b04ad4d..4b0ecd1de3bfd 100644 --- a/lib/AST/RequirementMachine/RequirementMachine.cpp +++ b/lib/AST/RequirementMachine/RequirementMachine.cpp @@ -139,6 +139,9 @@ RequirementMachine::RequirementMachine(RewriteContext &ctx) RequirementMachineStepLimit = langOpts.RequirementMachineStepLimit; RequirementMachineDepthLimit = langOpts.RequirementMachineDepthLimit; Stats = ctx.getASTContext().Stats; + + if (Stats) + ++Stats->getFrontendCounters().NumRequirementMachines; } RequirementMachine::~RequirementMachine() {} @@ -147,6 +150,8 @@ RequirementMachine::~RequirementMachine() {} /// /// This must only be called exactly once, before any other operations are /// performed on this requirement machine. +/// +/// Used by ASTContext::getOrCreateRequirementMachine(). void RequirementMachine::initWithGenericSignature(CanGenericSignature sig) { Sig = sig; Params.append(sig.getGenericParams().begin(), @@ -154,12 +159,6 @@ void RequirementMachine::initWithGenericSignature(CanGenericSignature sig) { PrettyStackTraceGenericSignature debugStack("building rewrite system for", sig); - auto &ctx = Context.getASTContext(); - auto *Stats = ctx.Stats; - - if (Stats) - ++Stats->getFrontendCounters().NumRequirementMachines; - FrontendStatsTracer tracer(Stats, "build-rewrite-system"); if (Dump) { @@ -189,15 +188,11 @@ void RequirementMachine::initWithGenericSignature(CanGenericSignature sig) { /// /// This must only be called exactly once, before any other operations are /// performed on this requirement machine. +/// +/// Used by RequirementSignatureRequest. void RequirementMachine::initWithProtocols(ArrayRef protos) { Protos = protos; - auto &ctx = Context.getASTContext(); - auto *Stats = ctx.Stats; - - if (Stats) - ++Stats->getFrontendCounters().NumRequirementMachines; - FrontendStatsTracer tracer(Stats, "build-rewrite-system"); if (Dump) { @@ -225,18 +220,17 @@ void RequirementMachine::initWithProtocols(ArrayRef protos } /// Build a requirement machine from a set of generic parameters and -/// (possibly non-canonical or non-minimal) structural requirements. +/// (possibly non-canonical or non-minimal) abstract requirements. +/// +/// This must only be called exactly once, before any other operations are +/// performed on this requirement machine. +/// +/// Used by AbstractGenericSignatureRequest. void RequirementMachine::initWithAbstractRequirements( ArrayRef genericParams, ArrayRef requirements) { Params.append(genericParams.begin(), genericParams.end()); - auto &ctx = Context.getASTContext(); - auto *Stats = ctx.Stats; - - if (Stats) - ++Stats->getFrontendCounters().NumRequirementMachines; - FrontendStatsTracer tracer(Stats, "build-rewrite-system"); if (Dump) { From 42c0a28ad7001f30007475a8285d5fc368a4715c Mon Sep 17 00:00:00 2001 From: Slava Pestov Date: Fri, 19 Nov 2021 15:45:51 -0500 Subject: [PATCH 4/9] RequirementMachine: Add RequirementMachine::initWithWrittenRequirements() --- .../RequirementLowering.cpp | 25 ++++++++++-- .../RequirementMachine/RequirementLowering.h | 3 ++ .../RequirementMachine/RequirementMachine.cpp | 38 +++++++++++++++++++ .../RequirementMachine/RequirementMachine.h | 3 ++ 4 files changed, 66 insertions(+), 3 deletions(-) diff --git a/lib/AST/RequirementMachine/RequirementLowering.cpp b/lib/AST/RequirementMachine/RequirementLowering.cpp index 9c75228fe53b9..105f66989040f 100644 --- a/lib/AST/RequirementMachine/RequirementLowering.cpp +++ b/lib/AST/RequirementMachine/RequirementLowering.cpp @@ -692,6 +692,21 @@ void RuleBuilder::addRequirements(ArrayRef requirements) { addRequirement(req, /*proto=*/nullptr); } +void RuleBuilder::addRequirements(ArrayRef requirements) { + // Collect all protocols transitively referenced from these requirements. + for (auto req : requirements) { + if (req.req.getKind() == RequirementKind::Conformance) { + addProtocol(req.req.getProtocolDecl(), /*initialComponent=*/false); + } + } + + collectRulesFromReferencedProtocols(); + + // Add rewrite rules for all top-level requirements. + for (const auto &req : requirements) + addRequirement(req, /*proto=*/nullptr); +} + void RuleBuilder::addProtocols(ArrayRef protos) { // Collect all protocols transitively referenced from this connected component // of the protocol dependency graph. @@ -849,6 +864,12 @@ void RuleBuilder::addRequirement(const Requirement &req, RequirementRules.emplace_back(subjectTerm, constraintTerm); } +void RuleBuilder::addRequirement(const StructuralRequirement &req, + const ProtocolDecl *proto) { + // FIXME: Preserve source location information for diagnostics. + addRequirement(req.req.getCanonical(), proto); +} + /// Record information about a protocol if we have no seen it yet. void RuleBuilder::addProtocol(const ProtocolDecl *proto, bool initialComponent) { @@ -902,10 +923,8 @@ void RuleBuilder::collectRulesFromReferencedProtocols() { // we can trigger the computation of the requirement signatures of the // next component recursively. if (ProtocolMap[proto]) { - // FIXME: Keep source location information around for redundancy - // diagnostics. for (auto req : proto->getStructuralRequirements()) - addRequirement(req.req.getCanonical(), proto); + addRequirement(req, proto); for (auto req : proto->getTypeAliasRequirements()) addRequirement(req.getCanonical(), proto); diff --git a/lib/AST/RequirementMachine/RequirementLowering.h b/lib/AST/RequirementMachine/RequirementLowering.h index 2f6e65051a923..d84142438e1c6 100644 --- a/lib/AST/RequirementMachine/RequirementLowering.h +++ b/lib/AST/RequirementMachine/RequirementLowering.h @@ -93,6 +93,7 @@ struct RuleBuilder { RuleBuilder(RewriteContext &ctx, bool dump) : Context(ctx), Dump(dump) {} void addRequirements(ArrayRef requirements); + void addRequirements(ArrayRef requirements); void addProtocols(ArrayRef proto); void addProtocol(const ProtocolDecl *proto, bool initialComponent); @@ -100,6 +101,8 @@ struct RuleBuilder { const ProtocolDecl *proto); void addRequirement(const Requirement &req, const ProtocolDecl *proto); + void addRequirement(const StructuralRequirement &req, + const ProtocolDecl *proto); void collectRulesFromReferencedProtocols(); }; diff --git a/lib/AST/RequirementMachine/RequirementMachine.cpp b/lib/AST/RequirementMachine/RequirementMachine.cpp index 4b0ecd1de3bfd..f6c50ff62cb79 100644 --- a/lib/AST/RequirementMachine/RequirementMachine.cpp +++ b/lib/AST/RequirementMachine/RequirementMachine.cpp @@ -257,6 +257,44 @@ void RequirementMachine::initWithAbstractRequirements( } } +/// Build a requirement machine from a set of generic parameters and +/// structural requirements. +/// +/// This must only be called exactly once, before any other operations are +/// performed on this requirement machine. +/// +/// Used by InferredGenericSignatureRequest. +void RequirementMachine::initWithWrittenRequirements( + ArrayRef genericParams, + ArrayRef requirements) { + Params.append(genericParams.begin(), genericParams.end()); + + FrontendStatsTracer tracer(Stats, "build-rewrite-system"); + + if (Dump) { + llvm::dbgs() << "Adding generic parameters:"; + for (auto *paramTy : genericParams) + llvm::dbgs() << " " << Type(paramTy); + llvm::dbgs() << "\n"; + } + + // Collect the top-level requirements, and all transtively-referenced + // protocol requirement signatures. + RuleBuilder builder(Context, Dump); + builder.addRequirements(requirements); + + // Add the initial set of rewrite rules to the rewrite system. + System.initialize(/*recordLoops=*/true, + std::move(builder.PermanentRules), + std::move(builder.RequirementRules)); + + computeCompletion(RewriteSystem::AllowInvalidRequirements); + + if (Dump) { + llvm::dbgs() << "}\n"; + } +} + /// Attempt to obtain a confluent rewrite system by iterating the Knuth-Bendix /// completion procedure together with property map construction until fixed /// point. diff --git a/lib/AST/RequirementMachine/RequirementMachine.h b/lib/AST/RequirementMachine/RequirementMachine.h index c2d76184f1649..0d6234046f3ac 100644 --- a/lib/AST/RequirementMachine/RequirementMachine.h +++ b/lib/AST/RequirementMachine/RequirementMachine.h @@ -86,6 +86,9 @@ class RequirementMachine final { void initWithAbstractRequirements( ArrayRef genericParams, ArrayRef requirements); + void initWithWrittenRequirements( + ArrayRef genericParams, + ArrayRef requirements); bool isComplete() const; From af9ea678ed99d0c197f6b57bd6fc43ba7673143b Mon Sep 17 00:00:00 2001 From: Slava Pestov Date: Fri, 19 Nov 2021 14:52:32 -0500 Subject: [PATCH 5/9] RequirementMachine: Implement InferredGenericSignatureRequest --- include/swift/AST/TypeCheckRequests.h | 45 +++++++- include/swift/AST/TypeCheckerTypeIDZone.def | 8 ++ .../RequirementMachine/RequirementMachine.h | 2 + .../RequirementMachineRequests.cpp | 106 ++++++++++++++++++ lib/AST/TypeCheckRequests.cpp | 8 ++ 5 files changed, 168 insertions(+), 1 deletion(-) diff --git a/include/swift/AST/TypeCheckRequests.h b/include/swift/AST/TypeCheckRequests.h index 532133226e3c2..4f516ff6997f3 100644 --- a/include/swift/AST/TypeCheckRequests.h +++ b/include/swift/AST/TypeCheckRequests.h @@ -1553,7 +1553,50 @@ class InferredGenericSignatureRequest : SourceLoc getNearestLoc() const { return SourceLoc(); } - + + // Cycle handling. + void noteCycleStep(DiagnosticEngine &diags) const; +}; + +/// Build a generic signature using the RequirementMachine. This is temporary; +/// once the GenericSignatureBuilder goes away this will be folded into +/// InferredGenericSignatureRequest. +class InferredGenericSignatureRequestRQM : + public SimpleRequest, + SmallVector, + bool), + RequestFlags::Cached> { +public: + using SimpleRequest::SimpleRequest; + +private: + friend SimpleRequest; + + // Evaluation. + GenericSignatureWithError + evaluate(Evaluator &evaluator, + ModuleDecl *parentModule, + const GenericSignatureImpl *baseSignature, + GenericParamList *genericParams, + WhereClauseOwner whereClause, + SmallVector addedRequirements, + SmallVector inferenceSources, + bool allowConcreteGenericParams) const; + +public: + // Separate caching. + bool isCached() const { return true; } + + /// Inferred generic signature requests don't have source-location info. + SourceLoc getNearestLoc() const { + return SourceLoc(); + } + // Cycle handling. void noteCycleStep(DiagnosticEngine &diags) const; }; diff --git a/include/swift/AST/TypeCheckerTypeIDZone.def b/include/swift/AST/TypeCheckerTypeIDZone.def index b34e85806f03f..da2a9b8027f19 100644 --- a/include/swift/AST/TypeCheckerTypeIDZone.def +++ b/include/swift/AST/TypeCheckerTypeIDZone.def @@ -143,6 +143,14 @@ SWIFT_REQUEST(TypeChecker, InferredGenericSignatureRequest, SmallVector, SmallVector, bool), Cached, NoLocationInfo) +SWIFT_REQUEST(TypeChecker, InferredGenericSignatureRequestRQM, + GenericSignatureWithError (ModuleDecl *, + const GenericSignatureImpl *, + GenericParamList *, + WhereClauseOwner, + SmallVector, + SmallVector, bool), + Cached, NoLocationInfo) SWIFT_REQUEST(TypeChecker, DistributedModuleIsAvailableRequest, bool(ModuleDecl *), Cached, NoLocationInfo) SWIFT_REQUEST(TypeChecker, InheritedTypeRequest, diff --git a/lib/AST/RequirementMachine/RequirementMachine.h b/lib/AST/RequirementMachine/RequirementMachine.h index 0d6234046f3ac..f43ff08965b67 100644 --- a/lib/AST/RequirementMachine/RequirementMachine.h +++ b/lib/AST/RequirementMachine/RequirementMachine.h @@ -32,6 +32,7 @@ class ASTContext; class AssociatedTypeDecl; class CanType; class GenericTypeParamType; +class InferredGenericSignatureRequestRQM; class LayoutConstraint; class ProtocolDecl; class Requirement; @@ -47,6 +48,7 @@ class RequirementMachine final { friend class swift::ASTContext; friend class swift::rewriting::RewriteContext; friend class swift::AbstractGenericSignatureRequestRQM; + friend class swift::InferredGenericSignatureRequestRQM; CanGenericSignature Sig; SmallVector Params; diff --git a/lib/AST/RequirementMachine/RequirementMachineRequests.cpp b/lib/AST/RequirementMachine/RequirementMachineRequests.cpp index 676adf28e0f46..4fb7508b33b6d 100644 --- a/lib/AST/RequirementMachine/RequirementMachineRequests.cpp +++ b/lib/AST/RequirementMachine/RequirementMachineRequests.cpp @@ -25,6 +25,7 @@ #include "swift/AST/LazyResolver.h" #include "swift/AST/Requirement.h" #include "swift/AST/TypeCheckRequests.h" +#include "swift/AST/TypeRepr.h" #include "swift/Basic/Statistic.h" #include "RequirementLowering.h" #include @@ -357,5 +358,110 @@ AbstractGenericSignatureRequestRQM::evaluate( bool hadError = false; auto result = GenericSignature::get(genericParams, minimalRequirements); + return GenericSignatureWithError(result, hadError); +} + +GenericSignatureWithError +InferredGenericSignatureRequestRQM::evaluate( + Evaluator &evaluator, + ModuleDecl *parentModule, + const GenericSignatureImpl *parentSigImpl, + GenericParamList *genericParamList, + WhereClauseOwner whereClause, + SmallVector addedRequirements, + SmallVector inferenceSources, + bool allowConcreteGenericParams) const { + GenericSignature parentSig(parentSigImpl); + + SmallVector genericParams( + parentSig.getGenericParams().begin(), + parentSig.getGenericParams().end()); + + SmallVector requirements; + for (const auto &req : parentSig.getRequirements()) + requirements.push_back({req, SourceLoc(), /*wasInferred=*/false}); + + const auto visitRequirement = [&](const Requirement &req, + RequirementRepr *reqRepr) { + realizeRequirement(req, reqRepr, /*infer=*/true, requirements); + return false; + }; + + if (genericParamList) { + // Extensions never have a parent signature. + assert(genericParamList->getOuterParameters() == nullptr || !parentSig); + + // Collect all outer generic parameter lists. + SmallVector gpLists; + for (auto *outerParamList = genericParamList; + outerParamList != nullptr; + outerParamList = outerParamList->getOuterParameters()) { + gpLists.push_back(outerParamList); + } + + // The generic parameter lists must appear from innermost to outermost. + // We walk them backwards to order outer parameters before inner + // parameters. + for (auto *gpList : llvm::reverse(gpLists)) { + assert(gpList->size() > 0 && + "Parsed an empty generic parameter list?"); + + for (auto *gpDecl : *gpList) { + auto *gpType = gpDecl->getDeclaredInterfaceType() + ->castTo(); + genericParams.push_back(gpType); + + realizeInheritedRequirements(gpDecl, gpType, /*infer=*/true, + requirements); + } + + // Add the generic parameter list's 'where' clause to the builder. + // + // The only time generic parameter lists have a 'where' clause is + // in SIL mode; all other generic declarations have a free-standing + // 'where' clause, which will be visited below. + WhereClauseOwner(parentModule, gpList) + .visitRequirements(TypeResolutionStage::Structural, + visitRequirement); + } + } + + if (whereClause) { + std::move(whereClause).visitRequirements( + TypeResolutionStage::Structural, + visitRequirement); + } + + // Perform requirement inference from function parameter and result + // types and such. + for (auto sourcePair : inferenceSources) { + auto *typeRepr = sourcePair.getTypeRepr(); + auto loc = typeRepr ? typeRepr->getStartLoc() : SourceLoc(); + + inferRequirements(sourcePair.getType(), loc, requirements); + } + + // Finish by adding any remaining requirements. This is used to introduce + // inferred same-type requirements when building the generic signature of + // an extension whose extended type is a generic typealias. + for (const auto &req : addedRequirements) + requirements.push_back({req, SourceLoc(), /*wasInferred=*/true}); + + // Heap-allocate the requirement machine to save stack space. + std::unique_ptr machine(new RequirementMachine( + parentModule->getASTContext().getRewriteContext())); + + machine->initWithWrittenRequirements(genericParams, requirements); + + auto minimalRequirements = + machine->computeMinimalGenericSignatureRequirements(); + + // FIXME: Implement this + bool hadError = false; + + auto result = GenericSignature::get(genericParams, minimalRequirements); + + // FIXME: Handle allowConcreteGenericParams + return GenericSignatureWithError(result, hadError); } \ No newline at end of file diff --git a/lib/AST/TypeCheckRequests.cpp b/lib/AST/TypeCheckRequests.cpp index c51c0eb951671..475d21e7d1412 100644 --- a/lib/AST/TypeCheckRequests.cpp +++ b/lib/AST/TypeCheckRequests.cpp @@ -766,6 +766,14 @@ void InferredGenericSignatureRequest::noteCycleStep(DiagnosticEngine &d) const { // into this request. See rdar://55263708 } +void InferredGenericSignatureRequestRQM::noteCycleStep(DiagnosticEngine &d) const { + // For now, the GSB does a better job of describing the exact structure of + // the cycle. + // + // FIXME: We should consider merging the circularity handling the GSB does + // into this request. See rdar://55263708 +} + //----------------------------------------------------------------------------// // UnderlyingTypeRequest computation. //----------------------------------------------------------------------------// From 163293e6f91b1760f0561a63db21947fdfe851af Mon Sep 17 00:00:00 2001 From: Slava Pestov Date: Fri, 19 Nov 2021 16:05:13 -0500 Subject: [PATCH 6/9] RequirementMachine: Split up -requirement-machine-generic-signatures flag into -requirement-machine-{abstract,inferred}-signatures --- include/swift/Basic/LangOptions.h | 10 ++++++++-- include/swift/Option/Options.td | 6 +++++- lib/AST/GenericSignatureBuilder.cpp | 2 +- lib/Frontend/CompilerInvocation.cpp | 18 ++++++++++++++++-- 4 files changed, 30 insertions(+), 6 deletions(-) diff --git a/include/swift/Basic/LangOptions.h b/include/swift/Basic/LangOptions.h index 97889ab1f66c1..5aad9ec12d184 100644 --- a/include/swift/Basic/LangOptions.h +++ b/include/swift/Basic/LangOptions.h @@ -489,8 +489,14 @@ namespace swift { RequirementMachineMode RequirementMachineProtocolSignatures = RequirementMachineMode::Disabled; - /// Enable the new experimental generic signature minimization algorithm. - RequirementMachineMode RequirementMachineGenericSignatures = + /// Enable the new experimental generic signature minimization algorithm + /// for abstract generic signatures. + RequirementMachineMode RequirementMachineAbstractSignatures = + RequirementMachineMode::Disabled; + + /// Enable the new experimental generic signature minimization algorithm + /// for user-written generic signatures. + RequirementMachineMode RequirementMachineInferredSignatures = RequirementMachineMode::Disabled; /// Sets the target we are building for and updates platform conditions diff --git a/include/swift/Option/Options.td b/include/swift/Option/Options.td index 29338aea4c6b7..5df786fd2d044 100644 --- a/include/swift/Option/Options.td +++ b/include/swift/Option/Options.td @@ -624,7 +624,11 @@ def requirement_machine_protocol_signatures_EQ : Joined<["-"], "requirement-mach Flags<[FrontendOption]>, HelpText<"Control usage of experimental protocol requirement signature minimization: 'on', 'off', or 'verify'">; -def requirement_machine_generic_signatures_EQ : Joined<["-"], "requirement-machine-generic-signatures=">, +def requirement_machine_abstract_signatures_EQ : Joined<["-"], "requirement-machine-abstract-signatures=">, + Flags<[FrontendOption]>, + HelpText<"Control usage of experimental generic signature minimization: 'on', 'off', or 'verify'">; + +def requirement_machine_inferred_signatures_EQ : Joined<["-"], "requirement-machine-inferred-signatures=">, Flags<[FrontendOption]>, HelpText<"Control usage of experimental generic signature minimization: 'on', 'off', or 'verify'">; diff --git a/lib/AST/GenericSignatureBuilder.cpp b/lib/AST/GenericSignatureBuilder.cpp index bbecef03ce454..f98486b74af8d 100644 --- a/lib/AST/GenericSignatureBuilder.cpp +++ b/lib/AST/GenericSignatureBuilder.cpp @@ -8612,7 +8612,7 @@ AbstractGenericSignatureRequest::evaluate( GenericSignatureWithError()); }; - switch (ctx.LangOpts.RequirementMachineGenericSignatures) { + switch (ctx.LangOpts.RequirementMachineAbstractSignatures) { case RequirementMachineMode::Disabled: return buildViaGSB(); diff --git a/lib/Frontend/CompilerInvocation.cpp b/lib/Frontend/CompilerInvocation.cpp index 34bbdc5d059bd..36f5a490b1284 100644 --- a/lib/Frontend/CompilerInvocation.cpp +++ b/lib/Frontend/CompilerInvocation.cpp @@ -864,7 +864,7 @@ static bool ParseLangArgs(LangOptions &Opts, ArgList &Args, A->getAsString(Args), A->getValue()); } - if (auto A = Args.getLastArg(OPT_requirement_machine_generic_signatures_EQ)) { + if (auto A = Args.getLastArg(OPT_requirement_machine_abstract_signatures_EQ)) { auto value = llvm::StringSwitch>(A->getValue()) .Case("off", RequirementMachineMode::Disabled) .Case("on", RequirementMachineMode::Enabled) @@ -872,7 +872,21 @@ static bool ParseLangArgs(LangOptions &Opts, ArgList &Args, .Default(None); if (value) - Opts.RequirementMachineGenericSignatures = *value; + Opts.RequirementMachineAbstractSignatures = *value; + else + Diags.diagnose(SourceLoc(), diag::error_invalid_arg_value, + A->getAsString(Args), A->getValue()); + } + + if (auto A = Args.getLastArg(OPT_requirement_machine_inferred_signatures_EQ)) { + auto value = llvm::StringSwitch>(A->getValue()) + .Case("off", RequirementMachineMode::Disabled) + .Case("on", RequirementMachineMode::Enabled) + .Case("verify", RequirementMachineMode::Verify) + .Default(None); + + if (value) + Opts.RequirementMachineInferredSignatures = *value; else Diags.diagnose(SourceLoc(), diag::error_invalid_arg_value, A->getAsString(Args), A->getValue()); From 97e160b7ab470099acf8c719077dc9bd0e759a72 Mon Sep 17 00:00:00 2001 From: Slava Pestov Date: Fri, 19 Nov 2021 16:38:28 -0500 Subject: [PATCH 7/9] RequirementMachine: Wire up -requirement-machine-inferred-signatures flag --- lib/AST/GenericSignatureBuilder.cpp | 241 ++++++++++++++++------------ 1 file changed, 142 insertions(+), 99 deletions(-) diff --git a/lib/AST/GenericSignatureBuilder.cpp b/lib/AST/GenericSignatureBuilder.cpp index f98486b74af8d..c37d03593d908 100644 --- a/lib/AST/GenericSignatureBuilder.cpp +++ b/lib/AST/GenericSignatureBuilder.cpp @@ -8649,121 +8649,164 @@ InferredGenericSignatureRequest::evaluate( SmallVector addedRequirements, SmallVector inferenceSources, bool allowConcreteGenericParams) const { - - GenericSignatureBuilder builder(parentModule->getASTContext()); - - // If there is a parent context, add the generic parameters and requirements - // from that context. - builder.addGenericSignature(parentSig); - - DeclContext *lookupDC = nullptr; - - const auto visitRequirement = [&](const Requirement &req, - RequirementRepr *reqRepr) { - const auto source = FloatingRequirementSource::forExplicit( - reqRepr->getSeparatorLoc()); - - // If we're extending a protocol and adding a redundant requirement, - // for example, `extension Foo where Self: Foo`, then emit a - // diagnostic. - - if (auto decl = lookupDC->getAsDecl()) { - if (auto extDecl = dyn_cast(decl)) { - auto extType = extDecl->getDeclaredInterfaceType(); - auto extSelfType = extDecl->getSelfInterfaceType(); - auto reqLHSType = req.getFirstType(); - auto reqRHSType = req.getSecondType(); - - if (extType->isExistentialType() && - reqLHSType->isEqual(extSelfType) && - reqRHSType->isEqual(extType)) { - - auto &ctx = extDecl->getASTContext(); - ctx.Diags.diagnose(extDecl->getLoc(), - diag::protocol_extension_redundant_requirement, - extType->getString(), - extSelfType->getString(), - reqRHSType->getString()); + auto buildViaGSB = [&]() { + GenericSignatureBuilder builder(parentModule->getASTContext()); + + // If there is a parent context, add the generic parameters and requirements + // from that context. + builder.addGenericSignature(parentSig); + + DeclContext *lookupDC = nullptr; + + const auto visitRequirement = [&](const Requirement &req, + RequirementRepr *reqRepr) { + const auto source = FloatingRequirementSource::forExplicit( + reqRepr->getSeparatorLoc()); + + // If we're extending a protocol and adding a redundant requirement, + // for example, `extension Foo where Self: Foo`, then emit a + // diagnostic. + + if (auto decl = lookupDC->getAsDecl()) { + if (auto extDecl = dyn_cast(decl)) { + auto extType = extDecl->getDeclaredInterfaceType(); + auto extSelfType = extDecl->getSelfInterfaceType(); + auto reqLHSType = req.getFirstType(); + auto reqRHSType = req.getSecondType(); + + if (extType->isExistentialType() && + reqLHSType->isEqual(extSelfType) && + reqRHSType->isEqual(extType)) { + + auto &ctx = extDecl->getASTContext(); + ctx.Diags.diagnose(extDecl->getLoc(), + diag::protocol_extension_redundant_requirement, + extType->getString(), + extSelfType->getString(), + reqRHSType->getString()); + } } } + + builder.addRequirement(req, reqRepr, source, nullptr, + lookupDC->getParentModule()); + return false; + }; + + if (genericParams) { + // Extensions never have a parent signature. + if (genericParams->getOuterParameters()) + assert(parentSig == nullptr); + + // Type check the generic parameters, treating all generic type + // parameters as dependent, unresolved. + SmallVector gpLists; + for (auto *outerParams = genericParams; + outerParams != nullptr; + outerParams = outerParams->getOuterParameters()) { + gpLists.push_back(outerParams); + } + + // The generic parameter lists MUST appear from innermost to outermost. + // We walk them backwards to order outer requirements before + // inner requirements. + for (auto &genericParams : llvm::reverse(gpLists)) { + assert(genericParams->size() > 0 && + "Parsed an empty generic parameter list?"); + + // First, add the generic parameters to the generic signature builder. + // Do this before checking the inheritance clause, since it may + // itself be dependent on one of these parameters. + for (const auto param : *genericParams) + builder.addGenericParameter(param); + + // Add the requirements for each of the generic parameters to the builder. + // Now, check the inheritance clauses of each parameter. + for (const auto param : *genericParams) + builder.addGenericParameterRequirements(param); + + // Determine where and how to perform name lookup. + lookupDC = genericParams->begin()[0]->getDeclContext(); + + // Add the requirements clause to the builder. + WhereClauseOwner(lookupDC, genericParams) + .visitRequirements(TypeResolutionStage::Structural, + visitRequirement); + } } - builder.addRequirement(req, reqRepr, source, nullptr, - lookupDC->getParentModule()); - return false; + if (whereClause) { + lookupDC = whereClause.dc; + std::move(whereClause).visitRequirements( + TypeResolutionStage::Structural, visitRequirement); + } + + /// Perform any remaining requirement inference. + for (auto sourcePair : inferenceSources) { + auto *typeRepr = sourcePair.getTypeRepr(); + auto source = + FloatingRequirementSource::forInferred( + typeRepr ? typeRepr->getStartLoc() : SourceLoc()); + + builder.inferRequirements(*parentModule, + sourcePair.getType(), + source); + } + + // Finish by adding any remaining requirements. + auto source = + FloatingRequirementSource::forInferred(SourceLoc()); + + for (const auto &req : addedRequirements) + builder.addRequirement(req, source, parentModule); + + bool hadError = builder.hadAnyError(); + auto result = std::move(builder).computeGenericSignature( + allowConcreteGenericParams); + return GenericSignatureWithError(result, hadError); }; - if (genericParams) { - // Extensions never have a parent signature. - if (genericParams->getOuterParameters()) - assert(parentSig == nullptr); + auto &ctx = parentModule->getASTContext(); - // Type check the generic parameters, treating all generic type - // parameters as dependent, unresolved. - SmallVector gpLists; - for (auto *outerParams = genericParams; - outerParams != nullptr; - outerParams = outerParams->getOuterParameters()) { - gpLists.push_back(outerParams); - } + auto buildViaRQM = [&]() { + return evaluateOrDefault( + ctx.evaluator, + InferredGenericSignatureRequestRQM{ + parentModule, + parentSig, + genericParams, + whereClause, + addedRequirements, + inferenceSources, + allowConcreteGenericParams}, + GenericSignatureWithError()); + }; - // The generic parameter lists MUST appear from innermost to outermost. - // We walk them backwards to order outer requirements before - // inner requirements. - for (auto &genericParams : llvm::reverse(gpLists)) { - assert(genericParams->size() > 0 && - "Parsed an empty generic parameter list?"); + switch (ctx.LangOpts.RequirementMachineInferredSignatures) { + case RequirementMachineMode::Disabled: + return buildViaGSB(); - // First, add the generic parameters to the generic signature builder. - // Do this before checking the inheritance clause, since it may - // itself be dependent on one of these parameters. - for (const auto param : *genericParams) - builder.addGenericParameter(param); + case RequirementMachineMode::Enabled: + return buildViaRQM(); - // Add the requirements for each of the generic parameters to the builder. - // Now, check the inheritance clauses of each parameter. - for (const auto param : *genericParams) - builder.addGenericParameterRequirements(param); + case RequirementMachineMode::Verify: { + auto rqmResult = buildViaRQM(); + auto gsbResult = buildViaGSB(); - // Determine where and how to perform name lookup. - lookupDC = genericParams->begin()[0]->getDeclContext(); + if (!rqmResult.getPointer() && !gsbResult.getPointer()) + return gsbResult; - // Add the requirements clause to the builder. - WhereClauseOwner(lookupDC, genericParams) - .visitRequirements(TypeResolutionStage::Structural, - visitRequirement); + if (!rqmResult.getPointer()->isEqual(gsbResult.getPointer())) { + llvm::errs() << "RequirementMachine generic signature minimization is broken:\n"; + llvm::errs() << "RequirementMachine says: " << rqmResult.getPointer() << "\n"; + llvm::errs() << "GenericSignatureBuilder says: " << gsbResult.getPointer() << "\n"; + + abort(); } - } - if (whereClause) { - lookupDC = whereClause.dc; - std::move(whereClause).visitRequirements( - TypeResolutionStage::Structural, visitRequirement); + return gsbResult; } - - /// Perform any remaining requirement inference. - for (auto sourcePair : inferenceSources) { - auto *typeRepr = sourcePair.getTypeRepr(); - auto source = - FloatingRequirementSource::forInferred( - typeRepr ? typeRepr->getStartLoc() : SourceLoc()); - - builder.inferRequirements(*parentModule, - sourcePair.getType(), - source); } - - // Finish by adding any remaining requirements. - auto source = - FloatingRequirementSource::forInferred(SourceLoc()); - - for (const auto &req : addedRequirements) - builder.addRequirement(req, source, parentModule); - - bool hadError = builder.hadAnyError(); - auto result = std::move(builder).computeGenericSignature( - allowConcreteGenericParams); - return GenericSignatureWithError(result, hadError); } ArrayRef From 291ddd7a31b25d3fb4f94d1bbc79228e6527e0ee Mon Sep 17 00:00:00 2001 From: Slava Pestov Date: Fri, 19 Nov 2021 17:03:11 -0500 Subject: [PATCH 8/9] RequirementMachine: Plumb through the ModuleDecl used for requirement inference --- .../RequirementLowering.cpp | 45 ++++++++++--------- .../RequirementMachine/RequirementLowering.h | 8 ++-- .../RequirementMachineRequests.cpp | 10 +++-- 3 files changed, 35 insertions(+), 28 deletions(-) diff --git a/lib/AST/RequirementMachine/RequirementLowering.cpp b/lib/AST/RequirementMachine/RequirementLowering.cpp index 105f66989040f..c5e9aec2580a7 100644 --- a/lib/AST/RequirementMachine/RequirementLowering.cpp +++ b/lib/AST/RequirementMachine/RequirementLowering.cpp @@ -186,7 +186,7 @@ swift::rewriting::desugarRequirement(Requirement req, // static void realizeTypeRequirement(Type subjectType, Type constraintType, - SourceLoc loc, bool wasInferred, + SourceLoc loc, SmallVectorImpl &result) { // Check whether we have a reasonable constraint type at all. if (!constraintType->isExistentialType() && @@ -207,7 +207,7 @@ static void realizeTypeRequirement(Type subjectType, Type constraintType, // Add source location information. for (auto req : reqs) - result.push_back({req, loc, wasInferred}); + result.push_back({req, loc, /*wasInferred=*/false}); } /// Infer requirements from applications of BoundGenericTypes to type @@ -218,7 +218,7 @@ static void realizeTypeRequirement(Type subjectType, Type constraintType, /// We automatically infer 'T : Hashable' from the fact that 'struct Set' /// declares a Hashable requirement on its generic parameter. void swift::rewriting::inferRequirements( - Type type, SourceLoc loc, + Type type, SourceLoc loc, ModuleDecl *module, SmallVectorImpl &result) { // FIXME: Implement } @@ -226,13 +226,14 @@ void swift::rewriting::inferRequirements( /// Desugar a requirement and perform requirement inference if requested /// to obtain zero or more structural requirements. void swift::rewriting::realizeRequirement( - Requirement req, RequirementRepr *reqRepr, bool infer, + Requirement req, RequirementRepr *reqRepr, + ModuleDecl *moduleForInference, SmallVectorImpl &result) { auto firstType = req.getFirstType(); - if (infer) { + if (moduleForInference) { auto firstLoc = (reqRepr ? reqRepr->getFirstTypeRepr()->getStartLoc() : SourceLoc()); - inferRequirements(firstType, firstLoc, result); + inferRequirements(firstType, firstLoc, moduleForInference, result); } auto loc = (reqRepr ? reqRepr->getSeparatorLoc() : SourceLoc()); @@ -241,14 +242,13 @@ void swift::rewriting::realizeRequirement( case RequirementKind::Superclass: case RequirementKind::Conformance: { auto secondType = req.getSecondType(); - if (infer) { + if (moduleForInference) { auto secondLoc = (reqRepr ? reqRepr->getSecondTypeRepr()->getStartLoc() : SourceLoc()); - inferRequirements(secondType, secondLoc, result); + inferRequirements(secondType, secondLoc, moduleForInference, result); } - realizeTypeRequirement(firstType, secondType, loc, /*wasInferred=*/false, - result); + realizeTypeRequirement(firstType, secondType, loc, result); break; } @@ -264,10 +264,10 @@ void swift::rewriting::realizeRequirement( case RequirementKind::SameType: { auto secondType = req.getSecondType(); - if (infer) { + if (moduleForInference) { auto secondLoc = (reqRepr ? reqRepr->getSecondTypeRepr()->getStartLoc() : SourceLoc()); - inferRequirements(secondType, secondLoc, result); + inferRequirements(secondType, secondLoc, moduleForInference, result); } SmallVector reqs; @@ -283,7 +283,7 @@ void swift::rewriting::realizeRequirement( /// Collect structural requirements written in the inheritance clause of an /// AssociatedTypeDecl or GenericTypeParamDecl. void swift::rewriting::realizeInheritedRequirements( - TypeDecl *decl, Type type, bool infer, + TypeDecl *decl, Type type, ModuleDecl *moduleForInference, SmallVectorImpl &result) { auto &ctx = decl->getASTContext(); auto inheritedTypes = decl->getInherited(); @@ -298,12 +298,11 @@ void swift::rewriting::realizeInheritedRequirements( auto *typeRepr = inheritedTypes[index].getTypeRepr(); SourceLoc loc = (typeRepr ? typeRepr->getStartLoc() : SourceLoc()); - if (infer) { - inferRequirements(inheritedType, loc, result); + if (moduleForInference) { + inferRequirements(inheritedType, loc, moduleForInference, result); } - realizeTypeRequirement(type, inheritedType, loc, /*wasInferred=*/false, - result); + realizeTypeRequirement(type, inheritedType, loc, result); } } @@ -319,12 +318,13 @@ StructuralRequirementsRequest::evaluate(Evaluator &evaluator, auto selfTy = proto->getSelfInterfaceType(); realizeInheritedRequirements(proto, selfTy, - /*infer=*/false, result); + /*moduleForInference=*/nullptr, result); // Add requirements from the protocol's own 'where' clause. WhereClauseOwner(proto).visitRequirements(TypeResolutionStage::Structural, [&](const Requirement &req, RequirementRepr *reqRepr) { - realizeRequirement(req, reqRepr, /*infer=*/false, result); + realizeRequirement(req, reqRepr, + /*moduleForInference=*/nullptr, result); return false; }); @@ -343,14 +343,17 @@ StructuralRequirementsRequest::evaluate(Evaluator &evaluator, for (auto assocTypeDecl : proto->getAssociatedTypeMembers()) { // Add requirements placed directly on this associated type. auto assocType = assocTypeDecl->getDeclaredInterfaceType(); - realizeInheritedRequirements(assocTypeDecl, assocType, /*infer=*/false, + realizeInheritedRequirements(assocTypeDecl, assocType, + /*moduleForInference=*/nullptr, result); // Add requirements from this associated type's where clause. WhereClauseOwner(assocTypeDecl).visitRequirements( TypeResolutionStage::Structural, [&](const Requirement &req, RequirementRepr *reqRepr) { - realizeRequirement(req, reqRepr, /*infer=*/false, result); + realizeRequirement(req, reqRepr, + /*moduleForInference=*/nullptr, + result); return false; }); } diff --git a/lib/AST/RequirementMachine/RequirementLowering.h b/lib/AST/RequirementMachine/RequirementLowering.h index d84142438e1c6..68f68285f5d88 100644 --- a/lib/AST/RequirementMachine/RequirementLowering.h +++ b/lib/AST/RequirementMachine/RequirementLowering.h @@ -41,13 +41,15 @@ namespace rewriting { void desugarRequirement(Requirement req, SmallVectorImpl &result); -void inferRequirements(Type type, SourceLoc loc, +void inferRequirements(Type type, SourceLoc loc, ModuleDecl *module, SmallVectorImpl &result); -void realizeRequirement(Requirement req, RequirementRepr *reqRepr, bool infer, +void realizeRequirement(Requirement req, RequirementRepr *reqRepr, + ModuleDecl *moduleForInference, SmallVectorImpl &result); -void realizeInheritedRequirements(TypeDecl *decl, Type type, bool infer, +void realizeInheritedRequirements(TypeDecl *decl, Type type, + ModuleDecl *moduleForInference, SmallVectorImpl &result); /// A utility class for bulding rewrite rules from the top-level requirements diff --git a/lib/AST/RequirementMachine/RequirementMachineRequests.cpp b/lib/AST/RequirementMachine/RequirementMachineRequests.cpp index 4fb7508b33b6d..431dffd56f3ac 100644 --- a/lib/AST/RequirementMachine/RequirementMachineRequests.cpp +++ b/lib/AST/RequirementMachine/RequirementMachineRequests.cpp @@ -373,6 +373,8 @@ InferredGenericSignatureRequestRQM::evaluate( bool allowConcreteGenericParams) const { GenericSignature parentSig(parentSigImpl); + auto &ctx = parentModule->getASTContext(); + SmallVector genericParams( parentSig.getGenericParams().begin(), parentSig.getGenericParams().end()); @@ -383,7 +385,7 @@ InferredGenericSignatureRequestRQM::evaluate( const auto visitRequirement = [&](const Requirement &req, RequirementRepr *reqRepr) { - realizeRequirement(req, reqRepr, /*infer=*/true, requirements); + realizeRequirement(req, reqRepr, parentModule, requirements); return false; }; @@ -411,7 +413,7 @@ InferredGenericSignatureRequestRQM::evaluate( ->castTo(); genericParams.push_back(gpType); - realizeInheritedRequirements(gpDecl, gpType, /*infer=*/true, + realizeInheritedRequirements(gpDecl, gpType, parentModule, requirements); } @@ -438,7 +440,7 @@ InferredGenericSignatureRequestRQM::evaluate( auto *typeRepr = sourcePair.getTypeRepr(); auto loc = typeRepr ? typeRepr->getStartLoc() : SourceLoc(); - inferRequirements(sourcePair.getType(), loc, requirements); + inferRequirements(sourcePair.getType(), loc, parentModule, requirements); } // Finish by adding any remaining requirements. This is used to introduce @@ -449,7 +451,7 @@ InferredGenericSignatureRequestRQM::evaluate( // Heap-allocate the requirement machine to save stack space. std::unique_ptr machine(new RequirementMachine( - parentModule->getASTContext().getRewriteContext())); + ctx.getRewriteContext())); machine->initWithWrittenRequirements(genericParams, requirements); From 77d4a207f86a3ae28e3486ee39e05fbec9c41036 Mon Sep 17 00:00:00 2001 From: Slava Pestov Date: Fri, 19 Nov 2021 17:04:40 -0500 Subject: [PATCH 9/9] RequirementMachine: Implement requirement inference --- .../RequirementLowering.cpp | 118 +++++++++++++++++- 1 file changed, 117 insertions(+), 1 deletion(-) diff --git a/lib/AST/RequirementMachine/RequirementLowering.cpp b/lib/AST/RequirementMachine/RequirementLowering.cpp index c5e9aec2580a7..d0edad46f0f0e 100644 --- a/lib/AST/RequirementMachine/RequirementLowering.cpp +++ b/lib/AST/RequirementMachine/RequirementLowering.cpp @@ -210,6 +210,115 @@ static void realizeTypeRequirement(Type subjectType, Type constraintType, result.push_back({req, loc, /*wasInferred=*/false}); } +namespace { + +/// AST walker that infers requirements from type representations. +struct InferRequirementsWalker : public TypeWalker { + ModuleDecl *module; + SmallVector reqs; + + explicit InferRequirementsWalker(ModuleDecl *module) : module(module) {} + + Action walkToTypePre(Type ty) override { + // Unbound generic types are the result of recovered-but-invalid code, and + // don't have enough info to do any useful substitutions. + if (ty->is()) + return Action::Stop; + + return Action::Continue; + } + + Action walkToTypePost(Type ty) override { + // Infer from generic typealiases. + if (auto typeAlias = dyn_cast(ty.getPointer())) { + auto decl = typeAlias->getDecl(); + auto subMap = typeAlias->getSubstitutionMap(); + for (const auto &rawReq : decl->getGenericSignature().getRequirements()) { + if (auto req = rawReq.subst(subMap)) + desugarRequirement(*req, reqs); + } + + return Action::Continue; + } + + // Infer requirements from `@differentiable` function types. + // For all non-`@noDerivative` parameter and result types: + // - `@differentiable`, `@differentiable(_forward)`, or + // `@differentiable(reverse)`: add `T: Differentiable` requirement. + // - `@differentiable(_linear)`: add + // `T: Differentiable`, `T == T.TangentVector` requirements. + if (auto *fnTy = ty->getAs()) { + auto &ctx = module->getASTContext(); + auto *differentiableProtocol = + ctx.getProtocol(KnownProtocolKind::Differentiable); + if (differentiableProtocol && fnTy->isDifferentiable()) { + auto addConformanceConstraint = [&](Type type, ProtocolDecl *protocol) { + Requirement req(RequirementKind::Conformance, type, + protocol->getDeclaredInterfaceType()); + desugarRequirement(req, reqs); + }; + auto addSameTypeConstraint = [&](Type firstType, + AssociatedTypeDecl *assocType) { + auto *protocol = assocType->getProtocol(); + auto *module = protocol->getParentModule(); + auto conf = module->lookupConformance(firstType, protocol); + auto secondType = conf.getAssociatedType( + firstType, assocType->getDeclaredInterfaceType()); + Requirement req(RequirementKind::SameType, firstType, secondType); + desugarRequirement(req, reqs); + }; + auto *tangentVectorAssocType = + differentiableProtocol->getAssociatedType(ctx.Id_TangentVector); + auto addRequirements = [&](Type type, bool isLinear) { + addConformanceConstraint(type, differentiableProtocol); + if (isLinear) + addSameTypeConstraint(type, tangentVectorAssocType); + }; + auto constrainParametersAndResult = [&](bool isLinear) { + for (auto ¶m : fnTy->getParams()) + if (!param.isNoDerivative()) + addRequirements(param.getPlainType(), isLinear); + addRequirements(fnTy->getResult(), isLinear); + }; + // Add requirements. + constrainParametersAndResult(fnTy->getDifferentiabilityKind() == + DifferentiabilityKind::Linear); + } + } + + if (!ty->isSpecialized()) + return Action::Continue; + + // Infer from generic nominal types. + auto decl = ty->getAnyNominal(); + if (!decl) return Action::Continue; + + // FIXME: The GSB and the request evaluator both detect a cycle here if we + // force a recursive generic signature. We should look into moving cycle + // detection into the generic signature request(s) - see rdar://55263708 + if (!decl->hasComputedGenericSignature()) + return Action::Continue; + + auto genericSig = decl->getGenericSignature(); + if (!genericSig) + return Action::Continue; + + /// Retrieve the substitution. + auto subMap = ty->getContextSubstitutionMap(module, decl); + + // Handle the requirements. + // FIXME: Inaccurate TypeReprs. + for (const auto &rawReq : genericSig.getRequirements()) { + if (auto req = rawReq.subst(subMap)) + desugarRequirement(*req, reqs); + } + + return Action::Continue; + } +}; + +} + /// Infer requirements from applications of BoundGenericTypes to type /// parameters. For example, given a function declaration /// @@ -220,7 +329,14 @@ static void realizeTypeRequirement(Type subjectType, Type constraintType, void swift::rewriting::inferRequirements( Type type, SourceLoc loc, ModuleDecl *module, SmallVectorImpl &result) { - // FIXME: Implement + if (!type) + return; + + InferRequirementsWalker walker(module); + type.walk(walker); + + for (const auto &req : walker.reqs) + result.push_back({req, loc, /*wasInferred=*/true}); } /// Desugar a requirement and perform requirement inference if requested