From 0c777e7534a5f3603401bb5b42f982e534499944 Mon Sep 17 00:00:00 2001 From: Slava Pestov Date: Wed, 2 Feb 2022 14:40:01 -0500 Subject: [PATCH 01/25] RequirementMachine: Split up PropertyMap::addProperty() --- lib/AST/RequirementMachine/PropertyMap.h | 5 + .../PropertyUnification.cpp | 201 ++++++++++-------- 2 files changed, 118 insertions(+), 88 deletions(-) diff --git a/lib/AST/RequirementMachine/PropertyMap.h b/lib/AST/RequirementMachine/PropertyMap.h index e8b4156248129..9cd6a97449848 100644 --- a/lib/AST/RequirementMachine/PropertyMap.h +++ b/lib/AST/RequirementMachine/PropertyMap.h @@ -245,6 +245,11 @@ class PropertyMap { void addProperty(Term key, Symbol property, unsigned ruleID); + void addConformanceProperty(Term key, Symbol property, unsigned ruleID); + void addLayoutProperty(Term key, Symbol property, unsigned ruleID); + void addSuperclassProperty(Term key, Symbol property, unsigned ruleID); + void addConcreteTypeProperty(Term key, Symbol property, unsigned ruleID); + void checkConcreteTypeRequirements(); void concretizeNestedTypesFromConcreteParents(); diff --git a/lib/AST/RequirementMachine/PropertyUnification.cpp b/lib/AST/RequirementMachine/PropertyUnification.cpp index 58c6c8bf1992f..be36fe9542917 100644 --- a/lib/AST/RequirementMachine/PropertyUnification.cpp +++ b/lib/AST/RequirementMachine/PropertyUnification.cpp @@ -351,117 +351,142 @@ static std::pair unifySuperclasses( return std::make_pair(rhs, false); } -/// Record a protocol conformance, layout or superclass constraint on the given -/// key. Must be called in monotonically non-decreasing key order. 
-void PropertyMap::addProperty( +void PropertyMap::addConformanceProperty( + Term key, Symbol property, unsigned ruleID) { + auto *props = getOrCreateProperties(key); + props->ConformsTo.push_back(property.getProtocol()); + props->ConformsToRules.push_back(ruleID); +} + +void PropertyMap::addLayoutProperty( Term key, Symbol property, unsigned ruleID) { - assert(property.isProperty()); - assert(*System.getRule(ruleID).isPropertyRule() == property); auto *props = getOrCreateProperties(key); bool debug = Debug.contains(DebugFlags::ConcreteUnification); - switch (property.getKind()) { - case Symbol::Kind::Protocol: - props->ConformsTo.push_back(property.getProtocol()); - props->ConformsToRules.push_back(ruleID); + auto newLayout = property.getLayoutConstraint(); + + if (!props->Layout) { + // If we haven't seen a layout requirement before, just record it. + props->Layout = newLayout; + props->LayoutRule = ruleID; return; + } - case Symbol::Kind::Layout: { - auto newLayout = property.getLayoutConstraint(); - - if (!props->Layout) { - // If we haven't seen a layout requirement before, just record it. - props->Layout = newLayout; - props->LayoutRule = ruleID; - } else { - // Otherwise, compute the intersection. - assert(props->LayoutRule.hasValue()); - auto mergedLayout = props->Layout.merge(property.getLayoutConstraint()); - - // If the intersection is invalid, we have a conflict. - if (!mergedLayout->isKnownLayout()) { - recordConflict(key, *props->LayoutRule, ruleID, System); - return; - } + // Otherwise, compute the intersection. + assert(props->LayoutRule.hasValue()); + auto mergedLayout = props->Layout.merge(property.getLayoutConstraint()); - // If the intersection is equal to the existing layout requirement, - // the new layout requirement is redundant. 
- if (mergedLayout == props->Layout) { - if (checkRulePairOnce(*props->LayoutRule, ruleID)) { - recordRelation(key, *props->LayoutRule, property, System, debug); - } + // If the intersection is invalid, we have a conflict. + if (!mergedLayout->isKnownLayout()) { + recordConflict(key, *props->LayoutRule, ruleID, System); + return; + } - // If the intersection is equal to the new layout requirement, the - // existing layout requirement is redundant. - } else if (mergedLayout == newLayout) { - if (checkRulePairOnce(ruleID, *props->LayoutRule)) { - auto oldProperty = System.getRule(*props->LayoutRule).getLHS().back(); - recordRelation(key, ruleID, oldProperty, System, debug); - } + // If the intersection is equal to the existing layout requirement, + // the new layout requirement is redundant. + if (mergedLayout == props->Layout) { + if (checkRulePairOnce(*props->LayoutRule, ruleID)) { + recordRelation(key, *props->LayoutRule, property, System, debug); + } - props->LayoutRule = ruleID; - } else { - llvm::errs() << "Arbitrary intersection of layout requirements is " - << "supported yet\n"; - abort(); - } + // If the intersection is equal to the new layout requirement, the + // existing layout requirement is redundant. + } else if (mergedLayout == newLayout) { + if (checkRulePairOnce(ruleID, *props->LayoutRule)) { + auto oldProperty = System.getRule(*props->LayoutRule).getLHS().back(); + recordRelation(key, ruleID, oldProperty, System, debug); } - return; + props->LayoutRule = ruleID; + } else { + llvm::errs() << "Arbitrary intersection of layout requirements is " + << "supported yet\n"; + abort(); } +} - case Symbol::Kind::Superclass: { - if (checkRuleOnce(ruleID)) { - // A rule (T.[superclass: C] => T) induces a rule (T.[layout: L] => T), - // where L is either AnyObject or _NativeObject. 
- auto superclass = - property.getConcreteType()->getClassOrBoundGenericClass(); - auto layout = - LayoutConstraint::getLayoutConstraint( - superclass->getLayoutConstraintKind(), - Context.getASTContext()); - auto layoutSymbol = Symbol::forLayout(layout, Context); - - recordRelation(key, ruleID, layoutSymbol, System, debug); - } +void PropertyMap::addSuperclassProperty( + Term key, Symbol property, unsigned ruleID) { + auto *props = getOrCreateProperties(key); + bool debug = Debug.contains(DebugFlags::ConcreteUnification); - if (!props->Superclass) { - props->Superclass = property; - props->SuperclassRule = ruleID; - } else { - assert(props->SuperclassRule.hasValue()); - auto pair = unifySuperclasses(*props->Superclass, property, - System, debug); - props->Superclass = pair.first; - bool conflict = pair.second; - if (conflict) { - recordConflict(key, *props->SuperclassRule, ruleID, System); - return; - } - } + if (checkRuleOnce(ruleID)) { + // A rule (T.[superclass: C] => T) induces a rule (T.[layout: L] => T), + // where L is either AnyObject or _NativeObject. 
+ auto superclass = + property.getConcreteType()->getClassOrBoundGenericClass(); + auto layout = + LayoutConstraint::getLayoutConstraint( + superclass->getLayoutConstraintKind(), + Context.getASTContext()); + auto layoutSymbol = Symbol::forLayout(layout, Context); + + recordRelation(key, ruleID, layoutSymbol, System, debug); + } + if (!props->Superclass) { + props->Superclass = property; + props->SuperclassRule = ruleID; return; } - case Symbol::Kind::ConcreteType: { - if (!props->ConcreteType) { - props->ConcreteType = property; - props->ConcreteTypeRule = ruleID; - } else { - assert(props->ConcreteTypeRule.hasValue()); - bool conflict = unifyConcreteTypes(*props->ConcreteType, property, - System, debug); - if (conflict) { - recordConflict(key, *props->ConcreteTypeRule, ruleID, System); - return; - } - } + assert(props->SuperclassRule.hasValue()); + auto pair = unifySuperclasses(*props->Superclass, property, + System, debug); + props->Superclass = pair.first; + bool conflict = pair.second; + if (conflict) { + recordConflict(key, *props->SuperclassRule, ruleID, System); + } +} +void PropertyMap::addConcreteTypeProperty( + Term key, Symbol property, unsigned ruleID) { + auto *props = getOrCreateProperties(key); + bool debug = Debug.contains(DebugFlags::ConcreteUnification); + + if (!props->ConcreteType) { + props->ConcreteType = property; + props->ConcreteTypeRule = ruleID; return; } + assert(props->ConcreteTypeRule.hasValue()); + bool conflict = unifyConcreteTypes(*props->ConcreteType, property, + System, debug); + if (conflict) { + recordConflict(key, *props->ConcreteTypeRule, ruleID, System); + } +} + +/// Record a protocol conformance, layout or superclass constraint on the given +/// key. Must be called in monotonically non-decreasing key order. 
+void PropertyMap::addProperty( + Term key, Symbol property, unsigned ruleID) { + assert(property.isProperty()); + assert(*System.getRule(ruleID).isPropertyRule() == property); + + switch (property.getKind()) { + case Symbol::Kind::Protocol: + addConformanceProperty(key, property, ruleID); + return; + + case Symbol::Kind::Layout: + addLayoutProperty(key, property, ruleID); + return; + + case Symbol::Kind::Superclass: + addSuperclassProperty(key, property, ruleID); + return; + + case Symbol::Kind::ConcreteType: + addConcreteTypeProperty(key, property, ruleID); + return; + case Symbol::Kind::ConcreteConformance: - // FIXME + // Concrete conformance rules are not recorded in the property map, since + // they're not needed for unification, and generic signature queries don't + // care about them. return; case Symbol::Kind::Name: From b28d35fa224ca316163198991fbabe212bc07973 Mon Sep 17 00:00:00 2001 From: Slava Pestov Date: Tue, 1 Feb 2022 21:59:12 -0500 Subject: [PATCH 02/25] RequirementMachine: Introduce RewriteSystem::computeTypeDifference() --- lib/AST/CMakeLists.txt | 1 + lib/AST/RequirementMachine/RewriteSystem.cpp | 6 + lib/AST/RequirementMachine/RewriteSystem.h | 22 + lib/AST/RequirementMachine/TypeDifference.cpp | 514 ++++++++++++++++++ lib/AST/RequirementMachine/TypeDifference.h | 65 +++ 5 files changed, 608 insertions(+) create mode 100644 lib/AST/RequirementMachine/TypeDifference.cpp create mode 100644 lib/AST/RequirementMachine/TypeDifference.h diff --git a/lib/AST/CMakeLists.txt b/lib/AST/CMakeLists.txt index 1b443185cdfc9..00e6121f56757 100644 --- a/lib/AST/CMakeLists.txt +++ b/lib/AST/CMakeLists.txt @@ -91,6 +91,7 @@ add_swift_host_library(swiftAST STATIC RequirementMachine/RewriteSystem.cpp RequirementMachine/Symbol.cpp RequirementMachine/Term.cpp + RequirementMachine/TypeDifference.cpp SearchPathOptions.cpp SILLayout.cpp Stmt.cpp diff --git a/lib/AST/RequirementMachine/RewriteSystem.cpp b/lib/AST/RequirementMachine/RewriteSystem.cpp index 
de0fec1af4679..dfd6036ed95b8 100644 --- a/lib/AST/RequirementMachine/RewriteSystem.cpp +++ b/lib/AST/RequirementMachine/RewriteSystem.cpp @@ -731,6 +731,12 @@ void RewriteSystem::dump(llvm::raw_ostream &out) const { out << "- " << relation.first << " =>> " << relation.second << "\n"; } out << "}\n"; + out << "Type differences: {\n"; + for (const auto &difference : Differences) { + difference.dump(out); + out << "\n"; + } + out << "}\n"; out << "Rewrite loops: {\n"; for (const auto &loop : Loops) { if (loop.isDeleted()) diff --git a/lib/AST/RequirementMachine/RewriteSystem.h b/lib/AST/RequirementMachine/RewriteSystem.h index 0bd873db70d67..433b38a951ce8 100644 --- a/lib/AST/RequirementMachine/RewriteSystem.h +++ b/lib/AST/RequirementMachine/RewriteSystem.h @@ -21,6 +21,7 @@ #include "Symbol.h" #include "Term.h" #include "Trie.h" +#include "TypeDifference.h" namespace llvm { class raw_ostream; @@ -389,6 +390,9 @@ class RewriteSystem final { using Relation = std::pair; private: + /// The map's values are indices into the vector. The map is used for + /// uniquing, then the index is returned and lookups are performed into + /// the vector. llvm::DenseMap RelationMap; std::vector Relations; @@ -411,6 +415,24 @@ class RewriteSystem final { Symbol concreteConformanceSymbol, Symbol associatedTypeSymbol); +private: + /// The map's values are indices into the vector. The map is used for + /// uniquing, then the index is returned and lookups are performed into + /// the vector. 
+ llvm::DenseMap, unsigned> DifferenceMap; + std::vector Differences; + + unsigned recordTypeDifference(Symbol lhs, Symbol rhs, + const TypeDifference &difference); + +public: + bool + computeTypeDifference(Symbol lhs, Symbol rhs, + Optional &lhsDifferenceID, + Optional &rhsDifferenceID); + + const TypeDifference &getTypeDifference(unsigned index) const; + private: ////////////////////////////////////////////////////////////////////////////// /// diff --git a/lib/AST/RequirementMachine/TypeDifference.cpp b/lib/AST/RequirementMachine/TypeDifference.cpp new file mode 100644 index 0000000000000..3f0c0ac42caac --- /dev/null +++ b/lib/AST/RequirementMachine/TypeDifference.cpp @@ -0,0 +1,514 @@ +//===--- TypeDifference.cpp - Utility for concrete type unification -------===// +// +// This source file is part of the Swift.org open source project +// +// Copyright (c) 2021 Apple Inc. and the Swift project authors +// Licensed under Apache License v2.0 with Runtime Library Exception +// +// See https://swift.org/LICENSE.txt for license information +// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors +// +//===----------------------------------------------------------------------===// + +#include "TypeDifference.h" +#include "swift/AST/Types.h" +#include "swift/AST/TypeMatcher.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/raw_ostream.h" +#include "RewriteContext.h" +#include "RewriteSystem.h" +#include "Term.h" + +using namespace swift; +using namespace rewriting; + +void TypeDifference::dump(llvm::raw_ostream &out) const { + llvm::errs() << "LHS: " << LHS << "\n"; + llvm::errs() << "RHS: " << RHS << "\n"; + + for (const auto &pair : SameTypes) { + out << "- " << LHS.getSubstitutions()[pair.first] << " (#"; + out << pair.first << ") -> "; + out << RHS.getSubstitutions()[pair.second] << " (#"; + out << pair.second << ")\n"; + } + 
+ for (const auto &pair : ConcreteTypes) { + out << "- " << LHS.getSubstitutions()[pair.first] << " (#"; + out << pair.first << ") -> " << pair.second << "\n"; + } +} + +void TypeDifference::verify(RewriteContext &ctx) const { +#ifndef NDEBUG + +#define VERIFY(expr, str) \ + if (!(expr)) { \ + llvm::errs() << "TypeDifference::verify(): " << str << "\n"; \ + dump(llvm::errs()); \ + abort(); \ + } + + VERIFY(LHS.getKind() == RHS.getKind(), "Kind mismatch"); + + if (LHS == RHS) { + VERIFY(SameTypes.empty(), "Abstract substitutions with equal symbols"); + VERIFY(ConcreteTypes.empty(), "Concrete substitutions with equal symbols"); + } else { + VERIFY(!SameTypes.empty() || !ConcreteTypes.empty(), + "Missing substitutions with non-equal symbols"); + + llvm::DenseSet lhsVisited; + llvm::DenseSet rhsVisited; + + for (const auto &pair : SameTypes) { + auto first = LHS.getSubstitutions()[pair.first]; + auto second = RHS.getSubstitutions()[pair.second]; + VERIFY(first.compare(second, ctx) > 0, "Order violation"); + + VERIFY(lhsVisited.insert(pair.first).second, "Duplicate substitutions"); + VERIFY(rhsVisited.insert(pair.second).second, "Duplicate substitutions"); + } + + for (const auto &pair : ConcreteTypes) { + VERIFY(pair.first < LHS.getSubstitutions().size(), + "Out-of-bounds substitution"); + VERIFY(lhsVisited.insert(pair.first).second, "Duplicate substitutions"); + VERIFY(pair.second.getKind() == Symbol::Kind::ConcreteType, "Bad kind"); + } + } + +#undef VERIFY +#endif +} + +namespace { + class ConcreteTypeMatcher : public TypeMatcher { + ArrayRef LHSSubstitutions; + ArrayRef RHSSubstitutions; + RewriteContext &Context; + + public: + /// Mismatches where both sides are type parameters and the left hand + /// side orders before the right hand side. The integers index the + /// RHSSubstitutions and LHSSubstitutions arrays, respectively. 
+ SmallVector, 1> SameTypesOnLHS; + + /// Mismatches where both sides are type parameters and the left hand + /// side orders after the right hand side. The integers index the + /// LHSSubstitutions and RHSSubstitutions arrays, respectively. + SmallVector, 1> SameTypesOnRHS; + + /// Mismatches where the left hand side is concrete and the right hand + /// side is a type parameter. The integer is an index into the + /// RHSSubstitutions array. + SmallVector, 1> ConcreteTypesOnLHS; + + /// Mismatches where the right hand side is concrete and the left hand + /// side is a type parameter. The integer is an index into the + /// LHSSubstitutions array. + SmallVector, 1> ConcreteTypesOnRHS; + + /// Mismatches where both sides are concrete; the presence of at least + /// one such mismatch indicates a conflict. + SmallVector, 1> ConcreteConflicts; + + ConcreteTypeMatcher(ArrayRef lhsSubstitutions, + ArrayRef rhsSubstitutions, + RewriteContext &ctx) + : LHSSubstitutions(lhsSubstitutions), + RHSSubstitutions(rhsSubstitutions), + Context(ctx) {} + + bool alwaysMismatchTypeParameters() const { return true; } + + bool mismatch(TypeBase *lhsType, TypeBase *rhsType, + Type sugaredFirstType) { + bool lhsAbstract = lhsType->isTypeParameter(); + bool rhsAbstract = rhsType->isTypeParameter(); + + if (lhsAbstract && rhsAbstract) { + unsigned lhsIndex = RewriteContext::getGenericParamIndex(lhsType); + unsigned rhsIndex = RewriteContext::getGenericParamIndex(rhsType); + + auto lhsTerm = LHSSubstitutions[lhsIndex]; + auto rhsTerm = RHSSubstitutions[rhsIndex]; + + int compare = lhsTerm.compare(rhsTerm, Context); + if (compare < 0) { + SameTypesOnLHS.emplace_back(rhsIndex, lhsIndex); + } else if (compare > 0) { + SameTypesOnRHS.emplace_back(lhsIndex, rhsIndex); + } else { + assert(lhsTerm == rhsTerm); + } + return true; + } + + if (lhsAbstract) { + assert(!rhsAbstract); + unsigned lhsIndex = RewriteContext::getGenericParamIndex(lhsType); + + SmallVector result; + auto rhsSchema = 
Context.getRelativeSubstitutionSchemaFromType( + CanType(rhsType), RHSSubstitutions, result); + auto rhsSymbol = Symbol::forConcreteType(rhsSchema, result, Context); + + ConcreteTypesOnRHS.emplace_back(lhsIndex, rhsSymbol); + return true; + } + + if (rhsAbstract) { + assert(!lhsAbstract); + unsigned rhsIndex = RewriteContext::getGenericParamIndex(rhsType); + + SmallVector result; + auto lhsSchema = Context.getRelativeSubstitutionSchemaFromType( + CanType(lhsType), LHSSubstitutions, result); + auto lhsSymbol = Symbol::forConcreteType(lhsSchema, result, Context); + + ConcreteTypesOnLHS.emplace_back(rhsIndex, lhsSymbol); + return true; + } + + // Any other kind of type mismatch involves conflicting concrete types on + // both sides, which can only happen on invalid input. + assert(!lhsAbstract && !rhsAbstract); + ConcreteConflicts.emplace_back(CanType(lhsType), CanType(rhsType)); + return true; + } + + void verify() const { +#ifndef NDEBUG + +#define VERIFY(expr, str) \ + if (!(expr)) { \ + llvm::errs() << "ConcreteTypeMatcher::verify(): " << str << "\n"; \ + dump(llvm::errs()); \ + abort(); \ + } + + llvm::DenseSet lhsVisited; + llvm::DenseSet rhsVisited; + + for (const auto &pair : SameTypesOnLHS) { + auto first = RHSSubstitutions[pair.first]; + auto second = LHSSubstitutions[pair.second]; + VERIFY(first.compare(second, Context) > 0, "Order violation"); + + VERIFY(rhsVisited.insert(pair.first).second, "Duplicate substitution"); + VERIFY(lhsVisited.insert(pair.second).second, "Duplicate substitution"); + } + + for (const auto &pair : SameTypesOnRHS) { + auto first = LHSSubstitutions[pair.first]; + auto second = RHSSubstitutions[pair.second]; + VERIFY(first.compare(second, Context) > 0, "Order violation"); + + VERIFY(lhsVisited.insert(pair.first).second, "Duplicate substitution"); + VERIFY(rhsVisited.insert(pair.second).second, "Duplicate substitution"); + } + + for (const auto &pair : ConcreteTypesOnLHS) { + VERIFY(pair.first < RHSSubstitutions.size(), + 
"Out-of-bounds substitution"); + VERIFY(rhsVisited.insert(pair.first).second, "Duplicate substitution"); + } + + for (const auto &pair : ConcreteTypesOnRHS) { + VERIFY(pair.first < LHSSubstitutions.size(), + "Out-of-bounds substitution"); + VERIFY(lhsVisited.insert(pair.first).second, "Duplicate substitution"); + } + +#undef VERIFY +#endif + } + + void dump(llvm::raw_ostream &out) const { + out << "Abstract differences with LHS < RHS:\n"; + for (const auto &pair : SameTypesOnLHS) { + out << "- " << RHSSubstitutions[pair.first] << " (#"; + out << pair.first << ") -> "; + out << LHSSubstitutions[pair.second] << " (#"; + out << pair.second << ")\n"; + } + + out << "Abstract differences with RHS < LHS:\n"; + for (const auto &pair : SameTypesOnRHS) { + out << "- " << LHSSubstitutions[pair.first] << " (#"; + out << pair.first << ") -> "; + out << RHSSubstitutions[pair.second] << " (#"; + out << pair.second << ")\n"; + } + + out << "Concrete differences with LHS < RHS:\n"; + for (const auto &pair : ConcreteTypesOnLHS) { + out << "- " << RHSSubstitutions[pair.first] << " (#"; + out << pair.first << ") -> " << pair.second << "\n"; + } + + out << "Concrete differences with RHS < LHS:\n"; + for (const auto &pair : ConcreteTypesOnRHS) { + out << "- " << LHSSubstitutions[pair.first] << " (#"; + out << pair.first << ") -> " << pair.second << "\n"; + } + + out << "Concrete conflicts:\n"; + for (const auto &pair : ConcreteConflicts) { + out << "- " << pair.first << " vs " << pair.second << "\n"; + } + } + }; +} + +static TypeDifference +computeMeet(Symbol symbol, Symbol otherSymbol, + const llvm::SmallVector, 1> &sameTypes, + const llvm::SmallVector, 1> &concreteTypes, + RewriteContext &ctx) { + assert(symbol.getKind() == otherSymbol.getKind()); + + auto &astCtx = ctx.getASTContext(); + + SmallVector resultSubstitutions; + SmallVector, 1> remappedSameTypes; + + auto nextSubstitution = [&](Term t) -> Type { + unsigned index = resultSubstitutions.size(); + 
resultSubstitutions.push_back(t); + return GenericTypeParamType::get(/*isTypeSequence=*/false, + /*depth=*/0, index, astCtx); + }; + + auto type = symbol.getConcreteType(); + auto substitutions = symbol.getSubstitutions(); + auto otherSubstitutions = otherSymbol.getSubstitutions(); + + Type resultType = type.transformRec([&](Type t) -> Optional { + if (t->is()) { + unsigned index = RewriteContext::getGenericParamIndex(t); + + for (const auto &pair : sameTypes) { + if (pair.first == index) { + remappedSameTypes.emplace_back(pair.first, + resultSubstitutions.size()); + return nextSubstitution(otherSubstitutions[pair.second]); + } + } + + for (const auto &pair : concreteTypes) { + if (pair.first == index) { + auto concreteSymbol = pair.second; + auto concreteType = concreteSymbol.getConcreteType(); + + return concreteType.transformRec([&](Type t) -> Optional { + if (t->is()) { + unsigned index = RewriteContext::getGenericParamIndex(t); + Term substitution = concreteSymbol.getSubstitutions()[index]; + return nextSubstitution(substitution); + } + + assert(!t->is()); + return None; + }); + } + } + + assert(!t->is()); + return nextSubstitution(substitutions[index]); + } + + return None; + }); + + auto resultSymbol = [&]() { + switch (symbol.getKind()) { + case Symbol::Kind::Superclass: + return Symbol::forSuperclass(CanType(resultType), + resultSubstitutions, ctx); + case Symbol::Kind::ConcreteType: + return Symbol::forConcreteType(CanType(resultType), + resultSubstitutions, ctx); + case Symbol::Kind::ConcreteConformance: + assert(symbol.getProtocol() == otherSymbol.getProtocol()); + return Symbol::forConcreteConformance(CanType(resultType), + resultSubstitutions, + symbol.getProtocol(), + ctx); + default: + break; + } + + llvm::report_fatal_error("Bad symbol kind"); + }(); + + return {symbol, resultSymbol, remappedSameTypes, concreteTypes}; +} + +unsigned +RewriteSystem::recordTypeDifference(Symbol lhs, Symbol rhs, + const TypeDifference &difference) { + assert(lhs == 
difference.LHS); + assert(rhs == difference.RHS); + assert(lhs != rhs); + + auto key = std::make_pair(lhs, rhs); + auto found = DifferenceMap.find(key); + if (found != DifferenceMap.end()) + return found->second; + + unsigned index = Differences.size(); + Differences.push_back(difference); + + auto inserted = DifferenceMap.insert(std::make_pair(key, index)); + assert(inserted.second); + (void) inserted; + + return index; +} + +const TypeDifference &RewriteSystem::getTypeDifference(unsigned index) const { + return Differences[index]; +} + +/// Computes the "meet" (LHS ∧ RHS) of two concrete type symbols (LHS and RHS +/// respectively), together with a set of transformations that turn LHS into +/// (LHS ∧ RHS) and RHS into (LHS ∧ RHS), respectively. +/// +/// Returns 0, 1 or 2 transformations via the two Optional<unsigned> +/// out parameters. The integer is an index that can be passed to +/// RewriteSystem::getTypeDifference() to return a TypeDifference. +/// +/// - If LHS == RHS, both lhsDifferenceID and rhsDifferenceID will be None. +/// +/// - If LHS == (LHS ∧ RHS), then lhsDifferenceID will be None. Otherwise, +/// lhsDifferenceID describes the transform from LHS to (LHS ∧ RHS). +/// +/// - If RHS == (LHS ∧ RHS), then rhsDifferenceID will be None. Otherwise, +/// rhsDifferenceID describes the transform from RHS to (LHS ∧ RHS). +/// +/// - If (LHS ∧ RHS) is distinct from both LHS and RHS, then both +/// lhsDifferenceID and rhsDifferenceID will be populated with a value. +/// +/// Also returns a boolean indicating if there was a concrete type conflict, +/// meaning that LHS and RHS had distinct concrete types at the same +/// position (eg, if LHS == Array<Int> and RHS == Array<String>). +/// +/// See the comment at the top of TypeDifference in TypeDifference.h for a +/// description of the actual transformations. 
+bool +RewriteSystem::computeTypeDifference(Symbol lhs, Symbol rhs, + Optional &lhsDifferenceID, + Optional &rhsDifferenceID) { + assert(lhs.getKind() == rhs.getKind()); + + lhsDifferenceID = None; + rhsDifferenceID = None; + + // Fast path if there's nothing to do. + if (lhs == rhs) + return false; + + // Match the types to find differences. + ConcreteTypeMatcher matcher(lhs.getSubstitutions(), + rhs.getSubstitutions(), + Context); + + bool success = matcher.match(lhs.getConcreteType(), + rhs.getConcreteType()); + assert(success); + (void) success; + + matcher.verify(); + + auto lhsMeetRhs = computeMeet(lhs, rhs, + matcher.SameTypesOnRHS, + matcher.ConcreteTypesOnRHS, + Context); + lhsMeetRhs.verify(Context); + + auto rhsMeetLhs = computeMeet(rhs, lhs, + matcher.SameTypesOnLHS, + matcher.ConcreteTypesOnLHS, + Context); + rhsMeetLhs.verify(Context); + + bool isConflict = (matcher.ConcreteConflicts.size() > 0); + +#ifndef NDEBUG + if (!isConflict) { + // The meet operation should be commutative. + if (lhsMeetRhs.RHS != rhsMeetLhs.RHS) { + llvm::errs() << "Meet operation was not commutative:\n\n"; + + llvm::errs() << "LHS: " << lhs << "\n"; + llvm::errs() << "RHS: " << rhs << "\n"; + matcher.dump(llvm::errs()); + + llvm::errs() << "\n"; + llvm::errs() << "LHS ∧ RHS: " << lhsMeetRhs.RHS << "\n"; + llvm::errs() << "RHS ∧ LHS: " << rhsMeetLhs.RHS << "\n"; + abort(); + } + + // The meet operation should be idempotent. 
+ { + // (LHS ∧ (LHS ∧ RHS)) == (LHS ∧ RHS) + auto lhsMeetLhsMeetRhs = computeMeet(lhs, lhsMeetRhs.RHS, + lhsMeetRhs.SameTypes, + lhsMeetRhs.ConcreteTypes, + Context); + + lhsMeetLhsMeetRhs.verify(Context); + + if (lhsMeetRhs.RHS != lhsMeetLhsMeetRhs.RHS) { + llvm::errs() << "Meet operation was not idempotent:\n\n"; + + llvm::errs() << "LHS: " << lhs << "\n"; + llvm::errs() << "RHS: " << rhs << "\n"; + matcher.dump(llvm::errs()); + + llvm::errs() << "\n"; + llvm::errs() << "LHS ∧ RHS: " << lhsMeetRhs.RHS << "\n"; + llvm::errs() << "LHS ∧ (LHS ∧ RHS): " << lhsMeetLhsMeetRhs.RHS << "\n"; + abort(); + } + } + + { + // (RHS ∧ (RHS ∧ LHS)) == (RHS ∧ LHS) + auto rhsMeetRhsMeetRhs = computeMeet(rhs, rhsMeetLhs.RHS, + rhsMeetLhs.SameTypes, + rhsMeetLhs.ConcreteTypes, + Context); + + rhsMeetRhsMeetRhs.verify(Context); + + if (lhsMeetRhs.RHS != rhsMeetRhsMeetRhs.RHS) { + llvm::errs() << "Meet operation was not idempotent:\n\n"; + + llvm::errs() << "LHS: " << lhs << "\n"; + llvm::errs() << "RHS: " << rhs << "\n"; + matcher.dump(llvm::errs()); + + llvm::errs() << "\n"; + llvm::errs() << "RHS ∧ LHS: " << rhsMeetLhs.RHS << "\n"; + llvm::errs() << "RHS ∧ (RHS ∧ LHS): " << rhsMeetRhsMeetRhs.RHS << "\n"; + abort(); + } + } + } +#endif + + if (lhs != lhsMeetRhs.RHS) + lhsDifferenceID = recordTypeDifference(lhs, lhsMeetRhs.RHS, lhsMeetRhs); + + if (rhs != rhsMeetLhs.RHS) + rhsDifferenceID = recordTypeDifference(rhs, rhsMeetLhs.RHS, rhsMeetLhs); + + return isConflict; +} \ No newline at end of file diff --git a/lib/AST/RequirementMachine/TypeDifference.h b/lib/AST/RequirementMachine/TypeDifference.h new file mode 100644 index 0000000000000..250e617285c99 --- /dev/null +++ b/lib/AST/RequirementMachine/TypeDifference.h @@ -0,0 +1,65 @@ +//===--- TypeDifference.h - Utility for concrete type unification ---------===// +// +// This source file is part of the Swift.org open source project +// +// Copyright (c) 2021 Apple Inc. 
and the Swift project authors +// Licensed under Apache License v2.0 with Runtime Library Exception +// +// See https://swift.org/LICENSE.txt for license information +// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors +// +//===----------------------------------------------------------------------===// + +#include "swift/AST/Type.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/Optional.h" +#include "llvm/ADT/SmallVector.h" +#include "Symbol.h" + +#ifndef TYPE_DIFFERENCE_H_ +#define TYPE_DIFFERENCE_H_ + +namespace llvm { + +class raw_ostream; + +} // end namespace llvm + +namespace swift { + +namespace rewriting { + +class RewriteContext; +class Term; + +/// Describes transformations that turn LHS into RHS. There are two kinds of +/// transformations: +/// +/// - Replacing a type term T1 with another type term T2, where T2 < T1. +/// - Replacing a type term T1 with a concrete type C2. +struct TypeDifference { + Symbol LHS; + Symbol RHS; + + /// A pair (N1, N2) where N1 is an index into LHS.getSubstitutions() and + /// N2 is an index into RHS.getSubstitutions(). + SmallVector, 1> SameTypes; + + /// A pair (N1, C2) where N1 is an index into LHS.getSubstitutions() and + /// C2 is a concrete type symbol. 
+ SmallVector, 1> ConcreteTypes; + + TypeDifference(Symbol lhs, Symbol rhs, + SmallVector, 1> sameTypes, + SmallVector, 1> concreteTypes) + : LHS(lhs), RHS(rhs), SameTypes(sameTypes), ConcreteTypes(concreteTypes) {} + + void dump(llvm::raw_ostream &out) const; + void verify(RewriteContext &ctx) const; +}; + +} // end namespace rewriting + +} // end namespace swift + +#endif \ No newline at end of file From 8eb8e8d86dd9ee8c300dc3915a5b23a05bbf839a Mon Sep 17 00:00:00 2001 From: Slava Pestov Date: Wed, 2 Feb 2022 17:57:34 -0500 Subject: [PATCH 03/25] RequirementMachine: Make RewriteSystem::recordRewriteLoop() public for use by the property map --- lib/AST/RequirementMachine/RewriteSystem.h | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/lib/AST/RequirementMachine/RewriteSystem.h b/lib/AST/RequirementMachine/RewriteSystem.h index 433b38a951ce8..078e4acd5fd07 100644 --- a/lib/AST/RequirementMachine/RewriteSystem.h +++ b/lib/AST/RequirementMachine/RewriteSystem.h @@ -456,9 +456,6 @@ class RewriteSystem final { /// algorithms. std::vector Loops; - void recordRewriteLoop(MutableTerm basepoint, - RewritePath path); - void propagateExplicitBits(); Optional @@ -474,6 +471,9 @@ class RewriteSystem final { llvm::DenseSet &redundantConformances); public: + void recordRewriteLoop(MutableTerm basepoint, + RewritePath path); + bool isInMinimizationDomain(ArrayRef protos) const; ArrayRef getLoops() const { From e19ee4d1a166dd99c5553dfbf0cda383d1f00777 Mon Sep 17 00:00:00 2001 From: Slava Pestov Date: Wed, 2 Feb 2022 15:02:30 -0500 Subject: [PATCH 04/25] RequirementMachine: Refactor PropertyMap::addConcreteTypeProperty() to use computeTypeDifference() This doesn't record rewrite loops from most concrete unifications just yet, only handling a case where two concrete types were identical except for an adjustment. 
--- .../PropertyUnification.cpp | 263 ++++++++++++++---- .../concrete_redundancy_via_adjustment.swift | 14 + 2 files changed, 226 insertions(+), 51 deletions(-) create mode 100644 test/Generics/concrete_redundancy_via_adjustment.swift diff --git a/lib/AST/RequirementMachine/PropertyUnification.cpp b/lib/AST/RequirementMachine/PropertyUnification.cpp index be36fe9542917..50730dc9923f8 100644 --- a/lib/AST/RequirementMachine/PropertyUnification.cpp +++ b/lib/AST/RequirementMachine/PropertyUnification.cpp @@ -223,55 +223,6 @@ namespace { }; } -/// When a type parameter has two concrete types, we have to unify the -/// type constructor arguments. -/// -/// For example, suppose that we have two concrete same-type requirements: -/// -/// T == Foo -/// T == Foo -/// -/// These lower to the following two rules: -/// -/// T.[concrete: Foo<τ_0_0, τ_0_1, String> with {X.Y, Z}] => T -/// T.[concrete: Foo with {A.B, W}] => T -/// -/// The two concrete type symbols will be added to the property bag of 'T', -/// and we will eventually end up in this method, where we will generate three -/// induced rules: -/// -/// X.Y.[concrete: Int] => X.Y -/// A.B => Z -/// W.[concrete: String] => W -/// -/// Returns the left hand side on success (it could also return the right hand -/// side; since we unified the type constructor arguments, it doesn't matter). -/// -/// Returns true if a conflict was detected. 
-static bool unifyConcreteTypes( - Symbol lhs, Symbol rhs, RewriteSystem &system, - bool debug) { - auto lhsType = lhs.getConcreteType(); - auto rhsType = rhs.getConcreteType(); - - if (debug) { - llvm::dbgs() << "% Unifying " << lhs << " with " << rhs << "\n"; - } - - ConcreteTypeMatcher matcher(lhs.getSubstitutions(), - rhs.getSubstitutions(), - system, debug); - if (!matcher.match(lhsType, rhsType)) { - // FIXME: Diagnose the conflict - if (debug) { - llvm::dbgs() << "%% Concrete type conflict\n"; - } - return true; - } - - return false; -} - /// When a type parameter has two superclasses, we have to both unify the /// type constructor arguments, and record the most derived superclass. /// @@ -440,9 +391,33 @@ void PropertyMap::addSuperclassProperty( } } +/// When a type parameter has two concrete types, we have to unify the +/// type constructor arguments. +/// +/// For example, suppose that we have two concrete same-type requirements: +/// +/// T == Foo +/// T == Foo +/// +/// These lower to the following two rules: +/// +/// T.[concrete: Foo<τ_0_0, τ_0_1, String> with {X.Y, Z}] => T +/// T.[concrete: Foo with {A.B, W}] => T +/// +/// The two concrete type symbols will be added to the property bag of 'T', +/// and we will eventually end up in this method, where we will generate three +/// induced rules: +/// +/// X.Y.[concrete: Int] => X.Y +/// A.B => Z +/// W.[concrete: String] => W void PropertyMap::addConcreteTypeProperty( Term key, Symbol property, unsigned ruleID) { auto *props = getOrCreateProperties(key); + + const auto &rule = System.getRule(ruleID); + assert(rule.getRHS() == key); + bool debug = Debug.contains(DebugFlags::ConcreteUnification); if (!props->ConcreteType) { @@ -452,10 +427,196 @@ void PropertyMap::addConcreteTypeProperty( } assert(props->ConcreteTypeRule.hasValue()); - bool conflict = unifyConcreteTypes(*props->ConcreteType, property, - System, debug); + + if (debug) { + llvm::dbgs() << "% Unifying " << *props->ConcreteType; + 
llvm::dbgs() << " with " << property << "\n"; + } + + Optional lhsDifferenceID; + Optional rhsDifferenceID; + + bool conflict = System.computeTypeDifference(*props->ConcreteType, property, + lhsDifferenceID, + rhsDifferenceID); + if (conflict) { + // FIXME: Diagnose the conflict + if (debug) { + llvm::dbgs() << "%% Concrete type conflict\n"; + } recordConflict(key, *props->ConcreteTypeRule, ruleID, System); + return; + } + + // Record induced rules from the given type difference. + auto processTypeDifference = [&](const TypeDifference &difference) { + if (debug) { + difference.dump(llvm::dbgs()); + } + + for (const auto &pair : difference.SameTypes) { + // Both sides are type parameters; add a same-type requirement. + MutableTerm lhsTerm(difference.LHS.getSubstitutions()[pair.first]); + MutableTerm rhsTerm(difference.RHS.getSubstitutions()[pair.second]); + + if (debug) { + llvm::dbgs() << "%% Induced rule " << lhsTerm + << " == " << rhsTerm << "\n"; + } + + // FIXME: Need a rewrite path here. + System.addRule(lhsTerm, rhsTerm); + } + + for (const auto &pair : difference.ConcreteTypes) { + // A type parameter is equated with a concrete type; add a concrete + // type requirement. + MutableTerm rhsTerm(difference.LHS.getSubstitutions()[pair.first]); + MutableTerm lhsTerm(rhsTerm); + lhsTerm.add(pair.second); + + if (debug) { + llvm::dbgs() << "%% Induced rule " << lhsTerm + << " == " << rhsTerm << "\n"; + } + + // FIXME: Need a rewrite path here. + System.addRule(lhsTerm, rhsTerm); + } + }; + + // Handle the case where (LHS ∧ RHS) is distinct from both LHS and RHS: + // - First, record a new rule. + // - Next, process the LHS -> (LHS ∧ RHS) difference. + // - Finally, process the RHS -> (LHS ∧ RHS) difference. 
+ if (lhsDifferenceID && rhsDifferenceID) { + const auto &lhsDifference = System.getTypeDifference(*lhsDifferenceID); + const auto &rhsDifference = System.getTypeDifference(*rhsDifferenceID); + + auto newProperty = lhsDifference.RHS; + assert(newProperty == rhsDifference.RHS); + + MutableTerm rhsTerm(key); + MutableTerm lhsTerm(key); + lhsTerm.add(newProperty); + + if (checkRulePairOnce(*props->ConcreteTypeRule, ruleID)) { + assert(lhsDifference.RHS == rhsDifference.RHS); + + if (debug) { + llvm::dbgs() << "%% Induced rule " << lhsTerm + << " == " << rhsTerm << "\n"; + } + + System.addRule(lhsTerm, rhsTerm); + } + + // Recover the (LHS ∧ RHS) rule. + RewritePath path; + bool simplified = System.simplify(lhsTerm, &path); + assert(simplified); + (void) simplified; + + assert(path.size() == 1); + assert(path.begin()->Kind == RewriteStep::Rule); + + unsigned newRuleID = path.begin()->getRuleID(); + + // Process LHS -> (LHS ∧ RHS). + if (checkRulePairOnce(*props->ConcreteTypeRule, newRuleID)) + processTypeDifference(lhsDifference); + + // Process RHS -> (LHS ∧ RHS). + if (checkRulePairOnce(ruleID, newRuleID)) + processTypeDifference(rhsDifference); + + // The new property is more specific, so update ConcreteType and + // ConcreteTypeRule. + props->ConcreteType = newProperty; + props->ConcreteTypeRule = ruleID; + + return; + } + + // Handle the case where RHS == (LHS ∧ RHS) by processing LHS -> (LHS ∧ RHS). + if (lhsDifferenceID) { + assert(!rhsDifferenceID); + + const auto &lhsDifference = System.getTypeDifference(*lhsDifferenceID); + assert(*props->ConcreteType == lhsDifference.LHS); + assert(property == lhsDifference.RHS); + + if (checkRulePairOnce(*props->ConcreteTypeRule, ruleID)) + processTypeDifference(lhsDifference); + + // The new property is more specific, so update ConcreteType and + // ConcreteTypeRule. 
+ props->ConcreteType = property; + props->ConcreteTypeRule = ruleID; + + return; + } + + // Handle the case where LHS == (LHS ∧ RHS) by processing RHS -> (LHS ∧ RHS). + if (rhsDifferenceID) { + assert(!lhsDifferenceID); + + const auto &rhsDifference = System.getTypeDifference(*rhsDifferenceID); + assert(property == rhsDifference.LHS); + assert(*props->ConcreteType == rhsDifference.RHS); + + if (checkRulePairOnce(*props->ConcreteTypeRule, ruleID)) + processTypeDifference(rhsDifference); + + // The new property is less specific, so ConcreteType and ConcreteTypeRule + // remain unchanged. + return; + } + + assert(property == *props->ConcreteType); + + if (*props->ConcreteTypeRule != ruleID) { + // If the rules are different but the concrete types are identical, then + // the key is some term U.V, the existing rule is a rule of the form: + // + // V.[concrete: G<...> with ] + // + // and the new rule is a rule of the form: + // + // U.V.[concrete: G<...> with ] + // + // Record a loop relating the two rules via a concrete type adjustment. + // Since the new rule appears without context, it becomes redundant. + if (checkRulePairOnce(*props->ConcreteTypeRule, ruleID)) { + const auto &otherRule = System.getRule(*props->ConcreteTypeRule); + assert(otherRule.getRHS().size() < key.size()); + + unsigned adjustment = (key.size() - otherRule.getRHS().size()); + + // Build a loop that rewrites U.V back into itself via the two rules, + // with a concrete type adjustment in the middle. + RewritePath path; + + // Add a rewrite step U.(V => V.[concrete: G<...> with ]). + path.add(RewriteStep::forRewriteRule(/*startOffset=*/adjustment, + /*endOffset=*/0, + *props->ConcreteTypeRule, + /*inverse=*/true)); + + // Add a concrete type adjustment. + path.add(RewriteStep::forAdjustment(/*startOffset=*/adjustment, + /*endOffset=*/0, + /*inverse=*/false)); + + // Add a rewrite step (U.V.[concrete: G<...> with ] => U.V).
+ path.add(RewriteStep::forRewriteRule(/*startOffset=*/0, + /*endOffset=*/0, + ruleID, + /*inverse=*/false)); + + System.recordRewriteLoop(MutableTerm(key), path); + } } } diff --git a/test/Generics/concrete_redundancy_via_adjustment.swift b/test/Generics/concrete_redundancy_via_adjustment.swift new file mode 100644 index 0000000000000..a5fe929ab9aa3 --- /dev/null +++ b/test/Generics/concrete_redundancy_via_adjustment.swift @@ -0,0 +1,14 @@ +// RUN: %target-swift-frontend -typecheck %s -debug-generic-signatures -requirement-machine-protocol-signatures=on 2>&1 | %FileCheck %s + +struct G {} + +protocol P1 { + associatedtype X where X == G + associatedtype Y +} + +// CHECK-LABEL: concrete_redundancy_via_adjustment.(file).P2@ +// CHECK-NEXT: Requirement signature: +protocol P2 { + associatedtype T : P1 where T.X == G +} From 73296edc639559c1c5a5b03c6a7a14f9b82fe76f Mon Sep 17 00:00:00 2001 From: Slava Pestov Date: Wed, 2 Feb 2022 22:41:04 -0500 Subject: [PATCH 05/25] RequirementMachine: Less indirect TypeDifference representation --- .../PropertyUnification.cpp | 2 +- lib/AST/RequirementMachine/TypeDifference.cpp | 99 ++++++++----------- lib/AST/RequirementMachine/TypeDifference.h | 6 +- 3 files changed, 43 insertions(+), 64 deletions(-) diff --git a/lib/AST/RequirementMachine/PropertyUnification.cpp b/lib/AST/RequirementMachine/PropertyUnification.cpp index 50730dc9923f8..55bbb67214836 100644 --- a/lib/AST/RequirementMachine/PropertyUnification.cpp +++ b/lib/AST/RequirementMachine/PropertyUnification.cpp @@ -458,7 +458,7 @@ void PropertyMap::addConcreteTypeProperty( for (const auto &pair : difference.SameTypes) { // Both sides are type parameters; add a same-type requirement. 
MutableTerm lhsTerm(difference.LHS.getSubstitutions()[pair.first]); - MutableTerm rhsTerm(difference.RHS.getSubstitutions()[pair.second]); + MutableTerm rhsTerm(pair.second); if (debug) { llvm::dbgs() << "%% Induced rule " << lhsTerm diff --git a/lib/AST/RequirementMachine/TypeDifference.cpp b/lib/AST/RequirementMachine/TypeDifference.cpp index 3f0c0ac42caac..2c1f36fefb9c2 100644 --- a/lib/AST/RequirementMachine/TypeDifference.cpp +++ b/lib/AST/RequirementMachine/TypeDifference.cpp @@ -31,9 +31,7 @@ void TypeDifference::dump(llvm::raw_ostream &out) const { for (const auto &pair : SameTypes) { out << "- " << LHS.getSubstitutions()[pair.first] << " (#"; - out << pair.first << ") -> "; - out << RHS.getSubstitutions()[pair.second] << " (#"; - out << pair.second << ")\n"; + out << pair.first << ") -> " << pair.second << "\n"; } for (const auto &pair : ConcreteTypes) { @@ -62,15 +60,11 @@ void TypeDifference::verify(RewriteContext &ctx) const { "Missing substitutions with non-equal symbols"); llvm::DenseSet lhsVisited; - llvm::DenseSet rhsVisited; for (const auto &pair : SameTypes) { auto first = LHS.getSubstitutions()[pair.first]; - auto second = RHS.getSubstitutions()[pair.second]; - VERIFY(first.compare(second, ctx) > 0, "Order violation"); - + VERIFY(first.compare(pair.second, ctx) > 0, "Order violation"); VERIFY(lhsVisited.insert(pair.first).second, "Duplicate substitutions"); - VERIFY(rhsVisited.insert(pair.second).second, "Duplicate substitutions"); } for (const auto &pair : ConcreteTypes) { @@ -93,14 +87,14 @@ namespace { public: /// Mismatches where both sides are type parameters and the left hand - /// side orders before the right hand side. The integers index the - /// LHSSubstitutions and RHSSubstitutions arrays, respectively. - SmallVector, 1> SameTypesOnLHS; + /// side orders before the right hand side. The integer is an index + /// into the LHSSubstitutions array. 
+ SmallVector, 1> SameTypesOnLHS; /// Mismatches where both sides are type parameters and the left hand - /// side orders after the right hand side. The integers index the - /// RHSSubstitutions and LHSSubstitutions arrays, respectively. - SmallVector, 1> SameTypesOnRHS; + /// side orders after the right hand side. The integer is an index + /// into the RHSSubstitutions array. + SmallVector, 1> SameTypesOnRHS; /// Mismatches where the left hand side is concrete and the right hand /// side is a type parameter. The integer is an index into the @@ -139,9 +133,9 @@ namespace { int compare = lhsTerm.compare(rhsTerm, Context); if (compare < 0) { - SameTypesOnLHS.emplace_back(rhsIndex, lhsIndex); + SameTypesOnLHS.emplace_back(rhsIndex, lhsTerm); } else if (compare > 0) { - SameTypesOnRHS.emplace_back(lhsIndex, rhsIndex); + SameTypesOnRHS.emplace_back(lhsIndex, rhsTerm); } else { assert(lhsTerm == rhsTerm); } @@ -196,20 +190,16 @@ namespace { for (const auto &pair : SameTypesOnLHS) { auto first = RHSSubstitutions[pair.first]; - auto second = LHSSubstitutions[pair.second]; - VERIFY(first.compare(second, Context) > 0, "Order violation"); + VERIFY(first.compare(pair.second, Context) > 0, "Order violation"); VERIFY(rhsVisited.insert(pair.first).second, "Duplicate substitution"); - VERIFY(lhsVisited.insert(pair.second).second, "Duplicate substitution"); } for (const auto &pair : SameTypesOnRHS) { auto first = LHSSubstitutions[pair.first]; - auto second = RHSSubstitutions[pair.second]; - VERIFY(first.compare(second, Context) > 0, "Order violation"); + VERIFY(first.compare(pair.second, Context) > 0, "Order violation"); VERIFY(lhsVisited.insert(pair.first).second, "Duplicate substitution"); - VERIFY(rhsVisited.insert(pair.second).second, "Duplicate substitution"); } for (const auto &pair : ConcreteTypesOnLHS) { @@ -232,17 +222,13 @@ namespace { out << "Abstract differences with LHS < RHS:\n"; for (const auto &pair : SameTypesOnLHS) { out << "- " << RHSSubstitutions[pair.first] << 
" (#"; - out << pair.first << ") -> "; - out << LHSSubstitutions[pair.second] << " (#"; - out << pair.second << ")\n"; + out << pair.first << ") -> " << pair.second << "\n"; } out << "Abstract differences with RHS < LHS:\n"; for (const auto &pair : SameTypesOnRHS) { out << "- " << LHSSubstitutions[pair.first] << " (#"; - out << pair.first << ") -> "; - out << RHSSubstitutions[pair.second] << " (#"; - out << pair.second << ")\n"; + out << pair.first << ") -> " << pair.second << "\n"; } out << "Concrete differences with LHS < RHS:\n"; @@ -266,16 +252,14 @@ namespace { } static TypeDifference -computeMeet(Symbol symbol, Symbol otherSymbol, - const llvm::SmallVector, 1> &sameTypes, - const llvm::SmallVector, 1> &concreteTypes, - RewriteContext &ctx) { - assert(symbol.getKind() == otherSymbol.getKind()); - +buildTypeDifference( + Symbol symbol, + const llvm::SmallVector, 1> &sameTypes, + const llvm::SmallVector, 1> &concreteTypes, + RewriteContext &ctx) { auto &astCtx = ctx.getASTContext(); SmallVector resultSubstitutions; - SmallVector, 1> remappedSameTypes; auto nextSubstitution = [&](Term t) -> Type { unsigned index = resultSubstitutions.size(); @@ -286,18 +270,14 @@ computeMeet(Symbol symbol, Symbol otherSymbol, auto type = symbol.getConcreteType(); auto substitutions = symbol.getSubstitutions(); - auto otherSubstitutions = otherSymbol.getSubstitutions(); Type resultType = type.transformRec([&](Type t) -> Optional { if (t->is()) { unsigned index = RewriteContext::getGenericParamIndex(t); for (const auto &pair : sameTypes) { - if (pair.first == index) { - remappedSameTypes.emplace_back(pair.first, - resultSubstitutions.size()); - return nextSubstitution(otherSubstitutions[pair.second]); - } + if (pair.first == index) + return nextSubstitution(pair.second); } for (const auto &pair : concreteTypes) { @@ -334,7 +314,6 @@ computeMeet(Symbol symbol, Symbol otherSymbol, return Symbol::forConcreteType(CanType(resultType), resultSubstitutions, ctx); case 
Symbol::Kind::ConcreteConformance: - assert(symbol.getProtocol() == otherSymbol.getProtocol()); return Symbol::forConcreteConformance(CanType(resultType), resultSubstitutions, symbol.getProtocol(), @@ -343,10 +322,10 @@ computeMeet(Symbol symbol, Symbol otherSymbol, break; } - llvm::report_fatal_error("Bad symbol kind"); + llvm_unreachable("Bad symbol kind"); }(); - return {symbol, resultSymbol, remappedSameTypes, concreteTypes}; + return {symbol, resultSymbol, sameTypes, concreteTypes}; } unsigned @@ -425,16 +404,16 @@ RewriteSystem::computeTypeDifference(Symbol lhs, Symbol rhs, matcher.verify(); - auto lhsMeetRhs = computeMeet(lhs, rhs, - matcher.SameTypesOnRHS, - matcher.ConcreteTypesOnRHS, - Context); + auto lhsMeetRhs = buildTypeDifference(lhs, + matcher.SameTypesOnRHS, + matcher.ConcreteTypesOnRHS, + Context); lhsMeetRhs.verify(Context); - auto rhsMeetLhs = computeMeet(rhs, lhs, - matcher.SameTypesOnLHS, - matcher.ConcreteTypesOnLHS, - Context); + auto rhsMeetLhs = buildTypeDifference(rhs, + matcher.SameTypesOnLHS, + matcher.ConcreteTypesOnLHS, + Context); rhsMeetLhs.verify(Context); bool isConflict = (matcher.ConcreteConflicts.size() > 0); @@ -458,10 +437,10 @@ RewriteSystem::computeTypeDifference(Symbol lhs, Symbol rhs, // The meet operation should be idempotent. 
{ // (LHS ∧ (LHS ∧ RHS)) == (LHS ∧ RHS) - auto lhsMeetLhsMeetRhs = computeMeet(lhs, lhsMeetRhs.RHS, - lhsMeetRhs.SameTypes, - lhsMeetRhs.ConcreteTypes, - Context); + auto lhsMeetLhsMeetRhs = buildTypeDifference(lhs, + lhsMeetRhs.SameTypes, + lhsMeetRhs.ConcreteTypes, + Context); lhsMeetLhsMeetRhs.verify(Context); @@ -481,10 +460,10 @@ RewriteSystem::computeTypeDifference(Symbol lhs, Symbol rhs, { // (RHS ∧ (RHS ∧ LHS)) == (RHS ∧ LHS) - auto rhsMeetRhsMeetRhs = computeMeet(rhs, rhsMeetLhs.RHS, - rhsMeetLhs.SameTypes, - rhsMeetLhs.ConcreteTypes, - Context); + auto rhsMeetRhsMeetRhs = buildTypeDifference(rhs, + rhsMeetLhs.SameTypes, + rhsMeetLhs.ConcreteTypes, + Context); rhsMeetRhsMeetRhs.verify(Context); diff --git a/lib/AST/RequirementMachine/TypeDifference.h b/lib/AST/RequirementMachine/TypeDifference.h index 250e617285c99..9955d8e135a0e 100644 --- a/lib/AST/RequirementMachine/TypeDifference.h +++ b/lib/AST/RequirementMachine/TypeDifference.h @@ -15,6 +15,7 @@ #include "llvm/ADT/Optional.h" #include "llvm/ADT/SmallVector.h" #include "Symbol.h" +#include "Term.h" #ifndef TYPE_DIFFERENCE_H_ #define TYPE_DIFFERENCE_H_ @@ -30,7 +31,6 @@ namespace swift { namespace rewriting { class RewriteContext; -class Term; /// Describes transformations that turn LHS into RHS. There are two kinds of /// transformations: @@ -43,14 +43,14 @@ struct TypeDifference { /// A pair (N1, N2) where N1 is an index into LHS.getSubstitutions() and /// N2 is an index into RHS.getSubstitutions(). - SmallVector, 1> SameTypes; + SmallVector, 1> SameTypes; /// A pair (N1, C2) where N1 is an index into LHS.getSubstitutions() and /// C2 is a concrete type symbol. 
SmallVector, 1> ConcreteTypes; TypeDifference(Symbol lhs, Symbol rhs, - SmallVector, 1> sameTypes, + SmallVector, 1> sameTypes, SmallVector, 1> concreteTypes) : LHS(lhs), RHS(rhs), SameTypes(sameTypes), ConcreteTypes(concreteTypes) {} From 7464a5f139f7f02ebffba9cf3da5b79f1ec4b062 Mon Sep 17 00:00:00 2001 From: Slava Pestov Date: Wed, 2 Feb 2022 23:24:37 -0500 Subject: [PATCH 06/25] RequirementMachine: Make recordTypeDifference() and buildTypeDifference() public --- lib/AST/RequirementMachine/RewriteSystem.h | 2 +- lib/AST/RequirementMachine/TypeDifference.cpp | 4 ++-- lib/AST/RequirementMachine/TypeDifference.h | 7 +++++++ 3 files changed, 10 insertions(+), 3 deletions(-) diff --git a/lib/AST/RequirementMachine/RewriteSystem.h b/lib/AST/RequirementMachine/RewriteSystem.h index 078e4acd5fd07..3f21ab6d17230 100644 --- a/lib/AST/RequirementMachine/RewriteSystem.h +++ b/lib/AST/RequirementMachine/RewriteSystem.h @@ -422,10 +422,10 @@ class RewriteSystem final { llvm::DenseMap, unsigned> DifferenceMap; std::vector Differences; +public: unsigned recordTypeDifference(Symbol lhs, Symbol rhs, const TypeDifference &difference); -public: bool computeTypeDifference(Symbol lhs, Symbol rhs, Optional &lhsDifferenceID, diff --git a/lib/AST/RequirementMachine/TypeDifference.cpp b/lib/AST/RequirementMachine/TypeDifference.cpp index 2c1f36fefb9c2..bc0b39d6c9438 100644 --- a/lib/AST/RequirementMachine/TypeDifference.cpp +++ b/lib/AST/RequirementMachine/TypeDifference.cpp @@ -251,8 +251,8 @@ namespace { }; } -static TypeDifference -buildTypeDifference( +TypeDifference +swift::rewriting::buildTypeDifference( Symbol symbol, const llvm::SmallVector, 1> &sameTypes, const llvm::SmallVector, 1> &concreteTypes, diff --git a/lib/AST/RequirementMachine/TypeDifference.h b/lib/AST/RequirementMachine/TypeDifference.h index 9955d8e135a0e..5d1f94c1b764c 100644 --- a/lib/AST/RequirementMachine/TypeDifference.h +++ b/lib/AST/RequirementMachine/TypeDifference.h @@ -58,6 +58,13 @@ struct TypeDifference 
{ void verify(RewriteContext &ctx) const; }; +TypeDifference +buildTypeDifference( + Symbol symbol, + const llvm::SmallVector, 1> &sameTypes, + const llvm::SmallVector, 1> &concreteTypes, + RewriteContext &ctx); + } // end namespace rewriting } // end namespace swift From d8aa79c5e575b7f66f03acec9f5cd0f0030e88cc Mon Sep 17 00:00:00 2001 From: Slava Pestov Date: Thu, 3 Feb 2022 15:01:26 -0500 Subject: [PATCH 07/25] RequirementMachine: Rename RewriteStep::AdjustConcreteType to ::PrefixSubstitutions --- .../ConcreteTypeWitness.cpp | 14 ++++++------- .../RequirementMachine/HomotopyReduction.cpp | 12 +++++------ lib/AST/RequirementMachine/KnuthBendix.cpp | 8 +++---- .../MinimalConformances.cpp | 2 +- .../PropertyUnification.cpp | 18 +++++++++------- lib/AST/RequirementMachine/RewriteLoop.cpp | 14 ++++++------- lib/AST/RequirementMachine/RewriteLoop.h | 21 +++++++++++-------- 7 files changed, 47 insertions(+), 42 deletions(-) diff --git a/lib/AST/RequirementMachine/ConcreteTypeWitness.cpp b/lib/AST/RequirementMachine/ConcreteTypeWitness.cpp index b20ceb0433548..78408cb509bb7 100644 --- a/lib/AST/RequirementMachine/ConcreteTypeWitness.cpp +++ b/lib/AST/RequirementMachine/ConcreteTypeWitness.cpp @@ -458,17 +458,17 @@ void PropertyMap::recordConcreteConformanceRule( /*ruleID=*/concreteRuleID, /*inverse=*/true)); - // Apply a concrete type adjustment to the concrete symbol if T' is shorter - // than T. + // If T' is a suffix of T, prepend the prefix to the concrete type's + // substitutions. 
auto concreteSymbol = *concreteRule.isPropertyRule(); - unsigned adjustment = rhs.size() - concreteRule.getRHS().size(); + unsigned prefixLength = rhs.size() - concreteRule.getRHS().size(); - if (adjustment > 0 && + if (prefixLength > 0 && !concreteConformanceSymbol.getSubstitutions().empty()) { - path.add(RewriteStep::forAdjustment(adjustment, /*endOffset=*/1, - /*inverse=*/false)); + path.add(RewriteStep::forPrefixSubstitutions(prefixLength, /*endOffset=*/1, + /*inverse=*/false)); - MutableTerm prefix(rhs.begin(), rhs.begin() + adjustment); + MutableTerm prefix(rhs.begin(), rhs.begin() + prefixLength); concreteSymbol = concreteSymbol.prependPrefixToConcreteSubstitutions( prefix, Context); } diff --git a/lib/AST/RequirementMachine/HomotopyReduction.cpp b/lib/AST/RequirementMachine/HomotopyReduction.cpp index 625f6ee526f7c..271103aafd49d 100644 --- a/lib/AST/RequirementMachine/HomotopyReduction.cpp +++ b/lib/AST/RequirementMachine/HomotopyReduction.cpp @@ -94,7 +94,7 @@ RewriteLoop::findRulesAppearingOnceInEmptyContext( break; } - case RewriteStep::AdjustConcreteType: + case RewriteStep::PrefixSubstitutions: case RewriteStep::Shift: case RewriteStep::Decompose: case RewriteStep::Relation: @@ -213,7 +213,7 @@ RewritePath RewritePath::splitCycleAtRule(unsigned ruleID) const { sawRule = true; continue; } - case RewriteStep::AdjustConcreteType: + case RewriteStep::PrefixSubstitutions: case RewriteStep::Shift: case RewriteStep::Decompose: case RewriteStep::Relation: @@ -279,7 +279,7 @@ bool RewritePath::replaceRuleWithPath(unsigned ruleID, break; } - auto adjustStep = [&](RewriteStep newStep) { + auto recontextualizeStep = [&](RewriteStep newStep) { bool inverse = newStep.Inverse ^ step.Inverse; if (newStep.Kind == RewriteStep::Decompose && inverse) { @@ -302,15 +302,15 @@ bool RewritePath::replaceRuleWithPath(unsigned ruleID, if (step.Inverse) { for (auto newStep : llvm::reverse(path)) - adjustStep(newStep); + recontextualizeStep(newStep); } else { for (auto newStep : 
path) - adjustStep(newStep); + recontextualizeStep(newStep); } break; } - case RewriteStep::AdjustConcreteType: + case RewriteStep::PrefixSubstitutions: case RewriteStep::Shift: case RewriteStep::Decompose: case RewriteStep::Relation: diff --git a/lib/AST/RequirementMachine/KnuthBendix.cpp b/lib/AST/RequirementMachine/KnuthBendix.cpp index 64b7bca440315..60312f416e311 100644 --- a/lib/AST/RequirementMachine/KnuthBendix.cpp +++ b/lib/AST/RequirementMachine/KnuthBendix.cpp @@ -452,15 +452,15 @@ RewriteSystem::computeCriticalPair(ArrayRef::const_iterator from, getRuleID(lhs), /*inverse=*/true)); - // (2) Next, if the right hand side rule ends with a concrete type symbol, - // perform the concrete type adjustment: + // (2) Next, if the right hand side rule ends with a superclass or concrete + // type symbol, remove the prefix 'T' from each substitution in the symbol. // // (σ - T) if (xv.back().hasSubstitutions() && !xv.back().getSubstitutions().empty() && t.size() > 0) { - path.add(RewriteStep::forAdjustment(t.size(), /*endOffset=*/0, - /*inverse=*/true)); + path.add(RewriteStep::forPrefixSubstitutions(t.size(), /*endOffset=*/0, + /*inverse=*/true)); xv.back() = xv.back().prependPrefixToConcreteSubstitutions( t, Context); diff --git a/lib/AST/RequirementMachine/MinimalConformances.cpp b/lib/AST/RequirementMachine/MinimalConformances.cpp index 9397a5e682e82..cc99ef856164d 100644 --- a/lib/AST/RequirementMachine/MinimalConformances.cpp +++ b/lib/AST/RequirementMachine/MinimalConformances.cpp @@ -134,7 +134,7 @@ void RewriteLoop::findProtocolConformanceRules( break; } - case RewriteStep::AdjustConcreteType: + case RewriteStep::PrefixSubstitutions: case RewriteStep::Shift: case RewriteStep::Decompose: case RewriteStep::Relation: diff --git a/lib/AST/RequirementMachine/PropertyUnification.cpp b/lib/AST/RequirementMachine/PropertyUnification.cpp index 55bbb67214836..ba12cf386cc91 100644 --- a/lib/AST/RequirementMachine/PropertyUnification.cpp +++ 
b/lib/AST/RequirementMachine/PropertyUnification.cpp @@ -586,28 +586,30 @@ void PropertyMap::addConcreteTypeProperty( // // U.V.[concrete: G<...> with ] // - // Record a loop relating the two rules via a concrete type adjustment. + // Record a loop relating the two rules via a rewrite step to prefix 'U' to + // the symbol's substitutions. + // // Since the new rule appears without context, it becomes redundant. if (checkRulePairOnce(*props->ConcreteTypeRule, ruleID)) { const auto &otherRule = System.getRule(*props->ConcreteTypeRule); assert(otherRule.getRHS().size() < key.size()); - unsigned adjustment = (key.size() - otherRule.getRHS().size()); + unsigned prefixLength = (key.size() - otherRule.getRHS().size()); // Build a loop that rewrites U.V back into itself via the two rules, - // with a concrete type adjustment in the middle. + // with a prefix substitutions step in the middle. RewritePath path; // Add a rewrite step U.(V => V.[concrete: G<...> with ]). - path.add(RewriteStep::forRewriteRule(/*startOffset=*/adjustment, + path.add(RewriteStep::forRewriteRule(/*startOffset=*/prefixLength, /*endOffset=*/0, *props->ConcreteTypeRule, /*inverse=*/true)); - // Add a concrete type adjustment. - path.add(RewriteStep::forAdjustment(/*startOffset=*/adjustment, - /*endOffset=*/0, - /*inverse=*/false)); + // Add a rewrite step to prefix 'U' to the substitutions. + path.add(RewriteStep::forPrefixSubstitutions(/*length=*/prefixLength, + /*endOffset=*/0, + /*inverse=*/false)); // Add a rewrite step (U.V.[concrete: G<...> with ] => U.V). 
path.add(RewriteStep::forRewriteRule(/*startOffset=*/0, diff --git a/lib/AST/RequirementMachine/RewriteLoop.cpp b/lib/AST/RequirementMachine/RewriteLoop.cpp index 810d25e83ada2..df6adcca8c894 100644 --- a/lib/AST/RequirementMachine/RewriteLoop.cpp +++ b/lib/AST/RequirementMachine/RewriteLoop.cpp @@ -85,8 +85,8 @@ void RewriteStep::dump(llvm::raw_ostream &out, break; } - case AdjustConcreteType: { - auto pair = evaluator.applyAdjustment(*this, system); + case PrefixSubstitutions: { + auto pair = evaluator.applyPrefixSubstitutions(*this, system); out << "(σ"; out << (Inverse ? " - " : " + "); @@ -260,11 +260,11 @@ RewritePathEvaluator::applyRewriteRule(const RewriteStep &step, } std::pair -RewritePathEvaluator::applyAdjustment(const RewriteStep &step, - const RewriteSystem &system) { +RewritePathEvaluator::applyPrefixSubstitutions(const RewriteStep &step, + const RewriteSystem &system) { auto &term = getCurrentTerm(); - assert(step.Kind == RewriteStep::AdjustConcreteType); + assert(step.Kind == RewriteStep::PrefixSubstitutions); auto &ctx = system.getRewriteContext(); MutableTerm prefix(term.begin() + step.StartOffset, @@ -451,8 +451,8 @@ void RewritePathEvaluator::apply(const RewriteStep &step, (void) applyRewriteRule(step, system); break; - case RewriteStep::AdjustConcreteType: - (void) applyAdjustment(step, system); + case RewriteStep::PrefixSubstitutions: + (void) applyPrefixSubstitutions(step, system); break; case RewriteStep::Shift: diff --git a/lib/AST/RequirementMachine/RewriteLoop.h b/lib/AST/RequirementMachine/RewriteLoop.h index 7d4165728db18..b574aaac373cc 100644 --- a/lib/AST/RequirementMachine/RewriteLoop.h +++ b/lib/AST/RequirementMachine/RewriteLoop.h @@ -62,7 +62,7 @@ struct RewriteStep { /// If inverted: strip the prefix from each substitution. /// /// The StartOffset field encodes the length of the prefix. 
- AdjustConcreteType, + PrefixSubstitutions, /// /// *** Rewrite step kinds introduced by simplifySubstitutions() *** @@ -130,10 +130,13 @@ struct RewriteStep { /// If Kind is Rule, the index of the rule in the rewrite system. /// - /// If Kind is AdjustConcreteType, the length of the prefix to add or remove + /// If Kind is PrefixSubstitutions, the length of the prefix to add or remove /// at the beginning of each concrete substitution. /// - /// If Kind is Concrete, the number of substitutions to push or pop. + /// If Kind is Decompose, the number of substitutions to push or pop. + /// + /// If Kind is Relation, the relation index returned from + /// RewriteSystem::recordRelation(). unsigned Arg : 16; RewriteStep(StepKind kind, unsigned startOffset, unsigned endOffset, @@ -154,10 +157,10 @@ struct RewriteStep { return RewriteStep(Rule, startOffset, endOffset, ruleID, inverse); } - static RewriteStep forAdjustment(unsigned offset, unsigned endOffset, - bool inverse) { - return RewriteStep(AdjustConcreteType, /*startOffset=*/0, endOffset, - /*arg=*/offset, inverse); + static RewriteStep forPrefixSubstitutions(unsigned length, unsigned endOffset, + bool inverse) { + return RewriteStep(PrefixSubstitutions, /*startOffset=*/0, endOffset, + /*arg=*/length, inverse); } static RewriteStep forShift(bool inverse) { @@ -355,8 +358,8 @@ struct RewritePathEvaluator { const RewriteSystem &system); std::pair - applyAdjustment(const RewriteStep &step, - const RewriteSystem &system); + applyPrefixSubstitutions(const RewriteStep &step, + const RewriteSystem &system); void applyShift(const RewriteStep &step, const RewriteSystem &system); From 7e4a5876d18914d209401e58b4a63d10a06083ab Mon Sep 17 00:00:00 2001 From: Slava Pestov Date: Thu, 3 Feb 2022 18:18:59 -0500 Subject: [PATCH 08/25] RequirementMachine: Introduce RewriteStep::DecomposeConcrete --- .../RequirementMachine/HomotopyReduction.cpp | 32 ++++-- .../MinimalConformances.cpp | 1 + lib/AST/RequirementMachine/RewriteLoop.cpp | 
102 ++++++++++++++++++ lib/AST/RequirementMachine/RewriteLoop.h | 51 ++++++++- 4 files changed, 177 insertions(+), 9 deletions(-) diff --git a/lib/AST/RequirementMachine/HomotopyReduction.cpp b/lib/AST/RequirementMachine/HomotopyReduction.cpp index 271103aafd49d..c8002d7e3ed90 100644 --- a/lib/AST/RequirementMachine/HomotopyReduction.cpp +++ b/lib/AST/RequirementMachine/HomotopyReduction.cpp @@ -98,6 +98,7 @@ RewriteLoop::findRulesAppearingOnceInEmptyContext( case RewriteStep::Shift: case RewriteStep::Decompose: case RewriteStep::Relation: + case RewriteStep::DecomposeConcrete: break; } @@ -217,6 +218,7 @@ RewritePath RewritePath::splitCycleAtRule(unsigned ruleID) const { case RewriteStep::Shift: case RewriteStep::Decompose: case RewriteStep::Relation: + case RewriteStep::DecomposeConcrete: break; } @@ -265,24 +267,31 @@ bool RewritePath::replaceRuleWithPath(unsigned ruleID, SmallVector newSteps; - // Keep track of Decompose/Compose pairs. Any rewrite steps in - // between do not need to be re-contextualized, since they - // operate on new terms that were pushed on the stack by the - // Compose operation. - unsigned decomposeCount = 0; - for (const auto &step : Steps) { switch (step.Kind) { case RewriteStep::Rule: { + // All other rewrite rules remain unchanged. if (step.getRuleID() != ruleID) { newSteps.push_back(step); break; } + // Ok, we found a rewrite step referencing the redundant rule. + // Replace this step with the provided path. If this rewrite step has + // context, the path's own steps must be re-contextualized. + + // Keep track of Decompose/DecomposeConcrete pairs. Any rewrite steps + // in between do not need to be re-contextualized, since they operate + // on new terms that were pushed on the stack by the Decompose or + // DecomposeConcrete operation. 
+ unsigned decomposeCount = 0; + auto recontextualizeStep = [&](RewriteStep newStep) { bool inverse = newStep.Inverse ^ step.Inverse; - if (newStep.Kind == RewriteStep::Decompose && inverse) { + if ((newStep.Kind == RewriteStep::Decompose || + newStep.Kind == RewriteStep::DecomposeConcrete) && + inverse) { assert(decomposeCount > 0); --decomposeCount; } @@ -295,11 +304,14 @@ bool RewritePath::replaceRuleWithPath(unsigned ruleID, newStep.Inverse = inverse; newSteps.push_back(newStep); - if (newStep.Kind == RewriteStep::Decompose && !inverse) { + if ((newStep.Kind == RewriteStep::Decompose || + newStep.Kind == RewriteStep::DecomposeConcrete) && + !inverse) { ++decomposeCount; } }; + // If this rewrite step is inverted, invert the entire path. if (step.Inverse) { for (auto newStep : llvm::reverse(path)) recontextualizeStep(newStep); @@ -308,12 +320,16 @@ bool RewritePath::replaceRuleWithPath(unsigned ruleID, recontextualizeStep(newStep); } + // Decompose and DecomposeConcrete steps should come in balanced pairs. 
+ assert(decomposeCount == 0); + break; } case RewriteStep::PrefixSubstitutions: case RewriteStep::Shift: case RewriteStep::Decompose: case RewriteStep::Relation: + case RewriteStep::DecomposeConcrete: newSteps.push_back(step); break; } diff --git a/lib/AST/RequirementMachine/MinimalConformances.cpp b/lib/AST/RequirementMachine/MinimalConformances.cpp index cc99ef856164d..c69c6b04cfbe4 100644 --- a/lib/AST/RequirementMachine/MinimalConformances.cpp +++ b/lib/AST/RequirementMachine/MinimalConformances.cpp @@ -138,6 +138,7 @@ void RewriteLoop::findProtocolConformanceRules( case RewriteStep::Shift: case RewriteStep::Decompose: case RewriteStep::Relation: + case RewriteStep::DecomposeConcrete: break; } } diff --git a/lib/AST/RequirementMachine/RewriteLoop.cpp b/lib/AST/RequirementMachine/RewriteLoop.cpp index df6adcca8c894..09a98dfc4f3d0 100644 --- a/lib/AST/RequirementMachine/RewriteLoop.cpp +++ b/lib/AST/RequirementMachine/RewriteLoop.cpp @@ -57,6 +57,7 @@ //===----------------------------------------------------------------------===// #include "swift/AST/Type.h" +#include "swift/Basic/Range.h" #include "llvm/Support/raw_ostream.h" #include #include "RewriteSystem.h" @@ -125,6 +126,16 @@ void RewriteStep::dump(llvm::raw_ostream &out, break; } + case DecomposeConcrete: { + evaluator.applyDecomposeConcrete(*this, system); + + out << (Inverse ? 
"ComposeConcrete(" : "DecomposeConcrete("); + + const auto &difference = system.getTypeDifference(Arg); + + out << difference.LHS << " : " << difference.RHS << ")"; + break; + } } } @@ -444,6 +455,93 @@ RewritePathEvaluator::applyRelation(const RewriteStep &step, return {lhs, rhs, prefix, suffix}; } +void RewritePathEvaluator::applyDecomposeConcrete(const RewriteStep &step, + const RewriteSystem &system) { + assert(step.Kind == RewriteStep::DecomposeConcrete); + + const auto &difference = system.getTypeDifference(step.Arg); + auto bug = [&](StringRef msg) { + llvm::errs() << msg << "\n"; + llvm::errs() << "- StartOffset: " << step.StartOffset << "\n"; + llvm::errs() << "- EndOffset: " << step.EndOffset << "\n"; + llvm::errs() << "- DifferenceID: " << step.Arg << "\n"; + llvm::errs() << "\nType difference:\n"; + difference.dump(llvm::errs()); + llvm::errs() << "\nEvaluator state:\n"; + dump(llvm::errs()); + abort(); + }; + + auto substitutions = difference.LHS.getSubstitutions(); + + auto getReplacementSubstitution = [&](unsigned n) -> MutableTerm { + for (const auto &pair : difference.SameTypes) { + if (pair.first == n) { + // Given a transformation Xn -> Xn', return the term Xn'. + return MutableTerm(pair.second); + } + } + + for (const auto &pair : difference.ConcreteTypes) { + if (pair.first == n) { + // Given a transformation Xn -> [concrete: D], return the + // term Xn.[concrete: D]. + MutableTerm result(substitutions[n]); + result.add(pair.second); + return result; + } + } + + // Otherwise return the original substitution Xn. 
+ return MutableTerm(substitutions[n]); + }; + + if (!step.Inverse) { + auto &term = getCurrentTerm(); + + auto concreteSymbol = *(term.end() - step.EndOffset - 1); + if (concreteSymbol != difference.RHS) + bug("Concrete symbol not equal to expected RHS"); + + MutableTerm newTerm(term.begin(), term.end() - step.EndOffset - 1); + newTerm.add(difference.LHS); + newTerm.append(term.end() - step.EndOffset, term.end()); + term = newTerm; + + for (unsigned n : indices(substitutions)) + Primary.push_back(getReplacementSubstitution(n)); + + } else { + unsigned numSubstitutions = substitutions.size(); + + if (Primary.size() < numSubstitutions + 1) + bug("Not enough terms on the stack"); + + for (unsigned n : indices(substitutions)) { + const auto &otherSubstitution = *(Primary.end() - numSubstitutions + n); + auto expectedSubstitution = getReplacementSubstitution(n); + if (otherSubstitution != expectedSubstitution) { + llvm::errs() << "Got: " << otherSubstitution << "\n"; + llvm::errs() << "Expected: " << expectedSubstitution << "\n"; + bug("Unexpected substitution term on the stack"); + } + } + + Primary.resize(Primary.size() - numSubstitutions); + + auto &term = getCurrentTerm(); + + auto concreteSymbol = *(term.end() - step.EndOffset - 1); + if (concreteSymbol != difference.LHS) + bug("Concrete symbol not equal to expected LHS"); + + MutableTerm newTerm(term.begin(), term.end() - step.EndOffset - 1); + newTerm.add(difference.RHS); + newTerm.append(term.end() - step.EndOffset, term.end()); + term = newTerm; + } +} + void RewritePathEvaluator::apply(const RewriteStep &step, const RewriteSystem &system) { switch (step.Kind) { @@ -466,5 +564,9 @@ void RewritePathEvaluator::apply(const RewriteStep &step, case RewriteStep::Relation: applyRelation(step, system); break; + + case RewriteStep::DecomposeConcrete: + applyDecomposeConcrete(step, system); + break; } } diff --git a/lib/AST/RequirementMachine/RewriteLoop.h b/lib/AST/RequirementMachine/RewriteLoop.h index 
b574aaac373cc..0efdb9e70ae4e 100644 --- a/lib/AST/RequirementMachine/RewriteLoop.h +++ b/lib/AST/RequirementMachine/RewriteLoop.h @@ -110,7 +110,48 @@ struct RewriteStep { /// /// The Arg field stores the result of calling /// RewriteSystem::recordRelation(). - Relation + Relation, + + /// A generalization of `Decompose` that can replace structural components + /// of the type with concrete types, using a TypeDifference that has been + /// computed previously. + /// + /// The Arg field is a TypeDifference ID, returned from + /// RewriteSystem::registerTypeDifference(). + /// + /// Say the TypeDifference LHS is [concrete: C<...> with ], and + /// say the TypeDifference RHS is [concrete: C'<...> with ]. + /// + /// Note that the LHS and RHS may have a different number of substitutions. + /// + /// If not inverted: the top of the primary stack must be a term ending + /// with the RHS of the TypeDifference: + /// + /// T.[concrete: C'<...> with ] + /// + /// First, the symbol at the end of the term is replaced by the LHS of the + /// TypeDifference: + /// + /// T.[concrete: C<...> with ] + /// + /// Then, each substitution of the LHS is pushed on the primary stack, with + /// the transforms of the TypeDifference applied: + /// + /// - If (n, f(Xn)) appears in TypeDifference::SameTypes, then we push + /// f(Xn). + /// - If (n, [concrete: D]) appears in TypeDifference::ConcreteTypes, then + /// we push Xn.[concrete: D]. + /// - Otherwise, we push Xn. + /// + /// This gives you something like: + /// + /// T.[concrete: C<...> with ] X1 f(X2) X3.[concrete: D] + /// + /// If inverted: the above is performed in reverse, leaving behind the + /// term ending with the TypeDifference RHS at the top of the primary stack: + /// + /// T.[concrete: C'<...> with ] + DecomposeConcrete }; /// The rewrite step kind. 
@@ -179,6 +220,11 @@ struct RewriteStep { /*arg=*/relationID, inverse); } + static RewriteStep forDecomposeConcrete(unsigned differenceID, bool inverse) { + return RewriteStep(DecomposeConcrete, /*startOffset=*/0, /*endOffset=*/0, + /*arg=*/differenceID, inverse); + } + bool isInContext() const { return StartOffset > 0 || EndOffset > 0; } @@ -371,6 +417,9 @@ struct RewritePathEvaluator { applyRelation(const RewriteStep &step, const RewriteSystem &system); + void applyDecomposeConcrete(const RewriteStep &step, + const RewriteSystem &system); + void dump(llvm::raw_ostream &out) const; }; From 730941b3eabd0b254183ad7857b19674fde2ef82 Mon Sep 17 00:00:00 2001 From: Slava Pestov Date: Wed, 2 Feb 2022 23:25:28 -0500 Subject: [PATCH 09/25] RequirementMachine: Implement PropertyMap::concretelySimplifyLeftHandSideSubstitutions() --- lib/AST/RequirementMachine/PropertyMap.cpp | 171 ++++++++++++++++++ lib/AST/RequirementMachine/PropertyMap.h | 5 + ...l_concrete_substitutions_in_protocol.swift | 34 ++++ test/Generics/unify_superclass_types_2.swift | 5 +- test/Generics/unify_superclass_types_3.swift | 3 +- 5 files changed, 215 insertions(+), 3 deletions(-) diff --git a/lib/AST/RequirementMachine/PropertyMap.cpp b/lib/AST/RequirementMachine/PropertyMap.cpp index 2d6e6e3a929a9..f4c821561671c 100644 --- a/lib/AST/RequirementMachine/PropertyMap.cpp +++ b/lib/AST/RequirementMachine/PropertyMap.cpp @@ -401,6 +401,10 @@ PropertyMap::buildPropertyMap(unsigned maxIterations, // the concrete type witnesses in the concrete type's conformance. concretizeNestedTypesFromConcreteParents(); + // Finally, a post-processing pass to reduce substitutions down to + // concrete types. 
+ concretelySimplifyLeftHandSideSubstitutions(); + unsigned addedNewRules = System.getRules().size() - ruleCount; for (unsigned i = ruleCount, e = System.getRules().size(); i < e; ++i) { const auto &newRule = System.getRule(i); @@ -417,6 +421,173 @@ PropertyMap::buildPropertyMap(unsigned maxIterations, return std::make_pair(CompletionResult::Success, addedNewRules); } +/// Similar to RewriteSystem::simplifySubstitutions(), but also replaces type +/// parameters with concrete types and builds a type difference describing +/// the transformation. +/// +/// Returns None if the concrete type symbol cannot be simplified further. +/// +/// Otherwise returns an index which can be passed to +/// RewriteSystem::getTypeDifference(). +Optional +PropertyMap::concretelySimplifySubstitutions(Symbol symbol, + RewritePath *path) const { + assert(symbol.hasSubstitutions()); + + // Fast path if the type is fully concrete. + auto substitutions = symbol.getSubstitutions(); + if (substitutions.empty()) + return None; + + // Save the original rewrite path length so that we can reset it if we don't + // find anything to simplify. + unsigned oldSize = (path ? path->size() : 0); + + if (path) { + // The term is at the top of the primary stack. Push all substitutions onto + // the primary stack. + path->add(RewriteStep::forDecompose(substitutions.size(), + /*inverse=*/false)); + + // Move all substitutions but the first one to the secondary stack. + for (unsigned i = 1; i < substitutions.size(); ++i) + path->add(RewriteStep::forShift(/*inverse=*/false)); + } + + // Simplify and collect substitutions. + llvm::SmallVector, 1> sameTypes; + llvm::SmallVector, 1> concreteTypes; + + for (unsigned index : indices(substitutions)) { + // Move the next substitution from the secondary stack to the primary stack. 
+ if (index != 0 && path) + path->add(RewriteStep::forShift(/*inverse=*/true)); + + auto term = symbol.getSubstitutions()[index]; + MutableTerm mutTerm(term); + + // Note that it's of course possible that the term both requires + // simplification, and the simplified term has a concrete type. + // + // This isn't handled with our current representation of + // TypeDifference, but that should be fine since the caller + // has to iterate until fixed point anyway. + // + // This should be rare in practice. + if (System.simplify(mutTerm, path)) { + // Record a mapping from this substitution to the simplified term. + sameTypes.emplace_back(index, Term::get(mutTerm, Context)); + } else { + auto *props = lookUpProperties(mutTerm); + + if (props && props->ConcreteType) { + // The property map entry might apply to a suffix of the substitution + // term, so prepend the appropriate prefix to its own substitutions. + auto prefix = props->getPrefixAfterStrippingKey(mutTerm); + auto concreteSymbol = + props->ConcreteType->prependPrefixToConcreteSubstitutions( + prefix, Context); + + // Record a mapping from this substitution to the concrete type. + concreteTypes.emplace_back(index, concreteSymbol); + + // If U.V is the substitution term and V is the property map key, + // apply the rewrite step U.(V => V.[concrete: C]) followed by + // prepending the prefix U to each substitution in the concrete type + // symbol if |U| > 0. + if (path) { + path->add(RewriteStep::forRewriteRule(/*startOffset=*/prefix.size(), + /*endOffset=*/0, + /*ruleID=*/*props->ConcreteTypeRule, + /*inverse=*/true)); + + path->add(RewriteStep::forPrefixSubstitutions(/*length=*/prefix.size(), + /*endOffset=*/0, + /*inverse=*/false)); + } + } + } + } + + // If nothing changed, we don't have to build the type difference. + if (sameTypes.empty() && concreteTypes.empty()) { + if (path) { + // The rewrite path should consist of a Decompose, followed by a number + // of Shifts, followed by a Compose. 
+ #ifndef NDEBUG + for (auto iter = path->begin() + oldSize; iter < path->end(); ++iter) { + assert(iter->Kind == RewriteStep::Shift || + iter->Kind == RewriteStep::Decompose); + } + #endif + + path->resize(oldSize); + } + return None; + } + + auto difference = buildTypeDifference(symbol, sameTypes, concreteTypes, + Context); + assert(difference.LHS != difference.RHS); + + unsigned differenceID = System.recordTypeDifference(difference.LHS, + difference.RHS, + difference); + + // All simplified substitutions are now on the primary stack. Collect them to + // produce the new term. + if (path) { + path->add(RewriteStep::forDecomposeConcrete(differenceID, + /*inverse=*/true)); + } + + return differenceID; +} + +void PropertyMap::concretelySimplifyLeftHandSideSubstitutions() const { + for (unsigned ruleID = 0, e = System.getRules().size(); ruleID < e; ++ruleID) { + auto &rule = System.getRule(ruleID); + if (rule.isLHSSimplified() || + rule.isRHSSimplified() || + rule.isSubstitutionSimplified()) + continue; + + auto symbol = rule.getLHS().back(); + if (!symbol.hasSubstitutions()) + continue; + + RewritePath path; + + auto differenceID = concretelySimplifySubstitutions(symbol, &path); + if (!differenceID) + continue; + + rule.markSubstitutionSimplified(); + + auto difference = System.getTypeDifference(*differenceID); + assert(difference.LHS == symbol); + + // If the original rule is (T.[concrete: C] => T) and [concrete: C'] is + // the simplified symbol, then difference.LHS == [concrete: C] and + // difference.RHS == [concrete: C'], and the rewrite path we just + // built takes T.[concrete: C] to T.[concrete: C']. + // + // We want a path from T.[concrete: C'] to T, so invert the path to get + // a path from T.[concrete: C'] to T.[concrete: C], and add a final step + // applying the original rule (T.[concrete: C] => T). 
+ path.invert(); + path.add(RewriteStep::forRewriteRule(/*startOffset=*/0, + /*endOffset=*/0, + /*ruleID=*/ruleID, + /*inverted=*/false)); + MutableTerm rhs(rule.getRHS()); + MutableTerm lhs(rhs); + lhs.add(difference.RHS); + + System.addRule(lhs, rhs, &path); + } +} + void PropertyMap::dump(llvm::raw_ostream &out) const { out << "Property map: {\n"; for (const auto &props : Entries) { diff --git a/lib/AST/RequirementMachine/PropertyMap.h b/lib/AST/RequirementMachine/PropertyMap.h index 9cd6a97449848..37561173b9d23 100644 --- a/lib/AST/RequirementMachine/PropertyMap.h +++ b/lib/AST/RequirementMachine/PropertyMap.h @@ -286,6 +286,11 @@ class PropertyMap { RequirementKind requirementKind, Symbol concreteConformanceSymbol) const; + Optional concretelySimplifySubstitutions(Symbol symbol, + RewritePath *path) const; + + void concretelySimplifyLeftHandSideSubstitutions() const; + void verify() const; }; diff --git a/test/Generics/canonical_concrete_substitutions_in_protocol.swift b/test/Generics/canonical_concrete_substitutions_in_protocol.swift index 13c5b3103b25c..17f4d0f9f5054 100644 --- a/test/Generics/canonical_concrete_substitutions_in_protocol.swift +++ b/test/Generics/canonical_concrete_substitutions_in_protocol.swift @@ -21,3 +21,37 @@ protocol R { associatedtype A associatedtype C: QQ where C.X == G } + +// Make sure substitutions which are themselves concrete simplify recursively. 
+ +// CHECK-LABEL: canonical_concrete_substitutions_in_protocol.(file).P1@ +// CHECK-NEXT: Requirement signature: > + +protocol P1 { + associatedtype T where T == Int + associatedtype U where U == G +} + +// CHECK-LABEL: canonical_concrete_substitutions_in_protocol.(file).P2@ +// CHECK-NEXT: Requirement signature: > + +protocol P2 { + associatedtype U where U == G + associatedtype T where T == Int +} + +// CHECK-LABEL: canonical_concrete_substitutions_in_protocol.(file).P3@ +// CHECK-NEXT: Requirement signature: , Self.[P3]U == Int> + +protocol P3 { + associatedtype T where T == G + associatedtype U where U == Int +} + +// CHECK-LABEL: canonical_concrete_substitutions_in_protocol.(file).P4@ +// CHECK-NEXT: Requirement signature: , Self.[P4]U == Int> + +protocol P4 { + associatedtype U where U == Int + associatedtype T where T == G +} diff --git a/test/Generics/unify_superclass_types_2.swift b/test/Generics/unify_superclass_types_2.swift index 0d3dfaf1267b5..f9b8c0a27f836 100644 --- a/test/Generics/unify_superclass_types_2.swift +++ b/test/Generics/unify_superclass_types_2.swift @@ -34,8 +34,9 @@ func unifySuperclassTest(_: T) { // CHECK: - τ_0_0.[P2:A2].[concrete: Int] => τ_0_0.[P2:A2] // CHECK-NEXT: - τ_0_0.[P1:A1].[concrete: String] => τ_0_0.[P1:A1] // CHECK-NEXT: - τ_0_0.[P2:B2] => τ_0_0.[P1:B1] +// CHECK-NEXT: - τ_0_0.[P1:X].[superclass: Generic<τ_0_0, String, τ_0_1> with <τ_0_0.[P2:A2], τ_0_0.[P1:B1]>] => τ_0_0.[P1:X] // CHECK-NEXT: - τ_0_0.B2 => τ_0_0.[P1:B1] -// CHECK: - τ_0_0.[P1:X].[superclass: Generic<τ_0_0, String, τ_0_1> with <τ_0_0.[P2:A2], τ_0_0.[P1:B1]>] => τ_0_0.[P1:X] +// CHECK: - τ_0_0.[P1:X].[superclass: Generic with <τ_0_0.[P1:B1]>] => τ_0_0.[P1:X] // CHECK: } // CHECK: Property map: { // CHECK-NEXT: [P1] => { conforms_to: [P1] } @@ -45,5 +46,5 @@ func unifySuperclassTest(_: T) { // CHECK-NEXT: τ_0_0 => { conforms_to: [P1 P2] } // CHECK-NEXT: τ_0_0.[P2:A2] => { concrete_type: [concrete: Int] } // CHECK-NEXT: τ_0_0.[P1:A1] => { concrete_type: 
[concrete: String] } -// CHECK-NEXT: τ_0_0.[P1:X] => { layout: _NativeClass superclass: [superclass: Generic<τ_0_0, String, τ_0_1> with <τ_0_0.[P2:A2], τ_0_0.[P1:B1]>] } +// CHECK-NEXT: τ_0_0.[P1:X] => { layout: _NativeClass superclass: [superclass: Generic with <τ_0_0.[P1:B1]>] } // CHECK-NEXT: } diff --git a/test/Generics/unify_superclass_types_3.swift b/test/Generics/unify_superclass_types_3.swift index 984d3416a247c..5ad1a7edf2d76 100644 --- a/test/Generics/unify_superclass_types_3.swift +++ b/test/Generics/unify_superclass_types_3.swift @@ -38,8 +38,9 @@ func unifySuperclassTest(_: T) { // CHECK: - τ_0_0.[P2:A2].[concrete: Int] => τ_0_0.[P2:A2] // CHECK-NEXT: - τ_0_0.[P1:A1].[concrete: String] => τ_0_0.[P1:A1] // CHECK-NEXT: - τ_0_0.[P2:B2] => τ_0_0.[P1:B1] -// CHECK-NEXT: - τ_0_0.B2 => τ_0_0.[P1:B1] // CHECK-NEXT: - τ_0_0.[P1:X].[superclass: Generic<τ_0_0, String, τ_0_1> with <τ_0_0.[P2:A2], τ_0_0.[P1:B1]>] => τ_0_0.[P1:X] +// CHECK-NEXT: - τ_0_0.B2 => τ_0_0.[P1:B1] +// CHECK-NEXT: - τ_0_0.[P1:X].[superclass: Generic with <τ_0_0.[P1:B1]>] => τ_0_0.[P1:X] // CHECK-NEXT: } // CHECK: Property map: { // CHECK-NEXT: [P1] => { conforms_to: [P1] } From 634ca55764a4cc066c5c7c85c5de43d006bdeac3 Mon Sep 17 00:00:00 2001 From: Slava Pestov Date: Thu, 3 Feb 2022 21:37:40 -0500 Subject: [PATCH 10/25] RequirementMachine: Rework completion limits a bit - Rename StepLimit to MaxRuleCount, DepthLimit to MaxRuleLength - Rename command line flags to -requirement-machine-max-rule-{count,length}= - Check limits outside of PropertyMap::buildPropertyMap() - Simplify the logic in RequirementMachine::computeCompletion() --- include/swift/AST/DiagnosticsSema.def | 2 +- include/swift/Basic/LangOptions.h | 6 +- include/swift/Option/FrontendOptions.td | 8 +- lib/AST/RequirementMachine/KnuthBendix.cpp | 33 +++---- lib/AST/RequirementMachine/PropertyMap.cpp | 33 ++----- lib/AST/RequirementMachine/PropertyMap.h | 4 +- .../RequirementMachine/RequirementMachine.cpp | 85 +++++++++++-------- 
.../RequirementMachine/RequirementMachine.h | 6 +- .../RequirementMachineRequests.cpp | 4 +- lib/AST/RequirementMachine/RewriteSystem.h | 20 ++--- lib/Frontend/CompilerInvocation.cpp | 8 +- test/Generics/non_confluent.swift | 10 +-- 12 files changed, 101 insertions(+), 118 deletions(-) diff --git a/include/swift/AST/DiagnosticsSema.def b/include/swift/AST/DiagnosticsSema.def index 5dbaf42caf6e2..bde55fef7c9aa 100644 --- a/include/swift/AST/DiagnosticsSema.def +++ b/include/swift/AST/DiagnosticsSema.def @@ -2559,7 +2559,7 @@ WARNING(associated_type_override_typealias,none, ERROR(requirement_machine_completion_failed,none, "cannot build rewrite system for %select{generic signature|protocol}0; " - "%select{step|depth}1 limit exceeded", + "%select{%error|rule count|rule length}1 limit exceeded", (unsigned, unsigned)) ERROR(associated_type_objc,none, diff --git a/include/swift/Basic/LangOptions.h b/include/swift/Basic/LangOptions.h index ab8a11ac38463..96a458fbc1261 100644 --- a/include/swift/Basic/LangOptions.h +++ b/include/swift/Basic/LangOptions.h @@ -510,13 +510,13 @@ namespace swift { /// Enables fine-grained debug output from the requirement machine. std::string DebugRequirementMachine; - /// Maximum iteration count for requirement machine Knuth-Bendix completion + /// Maximum rule count for requirement machine Knuth-Bendix completion /// algorithm. - unsigned RequirementMachineStepLimit = 4000; + unsigned RequirementMachineMaxRuleCount = 4000; /// Maximum term length for requirement machine Knuth-Bendix completion /// algorithm. - unsigned RequirementMachineDepthLimit = 12; + unsigned RequirementMachineMaxRuleLength = 12; /// Enable the new experimental protocol requirement signature minimization /// algorithm. 
diff --git a/include/swift/Option/FrontendOptions.td b/include/swift/Option/FrontendOptions.td index b241562bb2ceb..907a35be0a98b 100644 --- a/include/swift/Option/FrontendOptions.td +++ b/include/swift/Option/FrontendOptions.td @@ -343,13 +343,13 @@ def analyze_requirement_machine : Flag<["-"], "analyze-requirement-machine">, Flags<[FrontendOption, HelpHidden, DoesNotAffectIncrementalBuild]>, HelpText<"Print out requirement machine statistics at the end of the compilation job">; -def requirement_machine_step_limit : Separate<["-"], "requirement-machine-step-limit">, +def requirement_machine_max_rule_count : Joined<["-"], "requirement-machine-max-rule-count=">, Flags<[FrontendOption, HelpHidden, DoesNotAffectIncrementalBuild]>, - HelpText<"Set the maximum steps before we give up on confluent completion">; + HelpText<"Set the maximum number of rules before giving up">; -def requirement_machine_depth_limit : Separate<["-"], "requirement-machine-depth-limit">, +def requirement_machine_max_rule_length : Joined<["-"], "requirement-machine-max-rule-length=">, Flags<[FrontendOption, HelpHidden, DoesNotAffectIncrementalBuild]>, - HelpText<"Set the maximum depth before we give up on confluent completion">; + HelpText<"Set the maximum rule length before giving up">; def disable_requirement_machine_merged_associated_types : Flag<["-"], "disable-requirement-machine-merged-associated-types">, Flags<[FrontendOption, HelpHidden, DoesNotAffectIncrementalBuild]>, diff --git a/lib/AST/RequirementMachine/KnuthBendix.cpp b/lib/AST/RequirementMachine/KnuthBendix.cpp index 60312f416e311..147e975c3bfea 100644 --- a/lib/AST/RequirementMachine/KnuthBendix.cpp +++ b/lib/AST/RequirementMachine/KnuthBendix.cpp @@ -486,20 +486,19 @@ RewriteSystem::computeCriticalPair(ArrayRef::const_iterator from, return true; } -/// Computes the confluent completion using the Knuth-Bendix algorithm. +/// Computes the confluent completion using the Knuth-Bendix algorithm and +/// returns a status code. 
/// -/// Returns a pair consisting of a status and number of iterations executed. +/// The status is CompletionResult::MaxRuleCount if we add more than +/// \p maxRuleCount rules. /// -/// The status is CompletionResult::MaxIterations if we exceed \p maxIterations -/// iterations. -/// -/// The status is CompletionResult::MaxDepth if we produce a rewrite rule whose -/// left hand side has a length exceeding \p maxDepth. +/// The status is CompletionResult::MaxRuleLength if we produce a rewrite rule +/// whose left hand side has a length exceeding \p maxRuleLength. /// /// Otherwise, the status is CompletionResult::Success. -std::pair -RewriteSystem::computeConfluentCompletion(unsigned maxIterations, - unsigned maxDepth) { +CompletionResult +RewriteSystem::computeConfluentCompletion(unsigned maxRuleCount, + unsigned maxRuleLength) { assert(Initialized); assert(!Minimized); @@ -507,8 +506,6 @@ RewriteSystem::computeConfluentCompletion(unsigned maxIterations, // adding new rules in the property map's concrete type unification procedure. Complete = 1; - unsigned steps = 0; - bool again = false; std::vector resolvedCriticalPairs; @@ -605,18 +602,16 @@ RewriteSystem::computeConfluentCompletion(unsigned maxIterations, again = false; for (const auto &pair : resolvedCriticalPairs) { // Check if we've already done too much work. - if (Rules.size() > maxIterations) - return std::make_pair(CompletionResult::MaxIterations, steps); + if (Rules.size() > maxRuleCount) + return CompletionResult::MaxRuleCount; if (!addRule(pair.LHS, pair.RHS, &pair.Path)) continue; // Check if the new rule is too long. - if (Rules.back().getDepth() > maxDepth) - return std::make_pair(CompletionResult::MaxDepth, steps); + if (Rules.back().getDepth() > maxRuleLength) + return CompletionResult::MaxRuleLength; - // Only count a 'step' once we add a new rule. 
- ++steps; again = true; } @@ -640,5 +635,5 @@ RewriteSystem::computeConfluentCompletion(unsigned maxIterations, assert(MergedAssociatedTypes.empty() && "Should have processed all merge candidates"); - return std::make_pair(CompletionResult::Success, steps); + return CompletionResult::Success; } diff --git a/lib/AST/RequirementMachine/PropertyMap.cpp b/lib/AST/RequirementMachine/PropertyMap.cpp index f4c821561671c..df351a83a0231 100644 --- a/lib/AST/RequirementMachine/PropertyMap.cpp +++ b/lib/AST/RequirementMachine/PropertyMap.cpp @@ -323,18 +323,11 @@ void PropertyMap::clear() { /// Build the property map from all rules of the form T.[p] => T, where /// [p] is a property symbol. /// -/// Returns a pair consisting of a status and number of iterations executed. -/// -/// The status is CompletionResult::MaxIterations if we exceed \p maxIterations -/// iterations. -/// -/// The status is CompletionResult::MaxDepth if we produce a rewrite rule whose -/// left hand side has a length exceeding \p maxDepth. -/// -/// Otherwise, the status is CompletionResult::Success. -std::pair -PropertyMap::buildPropertyMap(unsigned maxIterations, - unsigned maxDepth) { +/// Also performs property unification, nested type concretization and +/// concrete simplification. These phases can add new rules; if new rules +/// were added, the caller must run another round of Knuth-Bendix +/// completion, and rebuild the property map again. +void PropertyMap::buildPropertyMap() { if (System.getDebugOptions().contains(DebugFlags::PropertyMap)) { llvm::dbgs() << "-------------------------\n"; llvm::dbgs() << "- Building property map -\n"; @@ -382,10 +375,6 @@ PropertyMap::buildPropertyMap(unsigned maxIterations, properties[length].push_back({rhs, *property, ruleID}); } - // Merging multiple superclass or concrete type rules can induce new rules - // to unify concrete type constructor arguments. 
- unsigned ruleCount = System.getRules().size(); - for (const auto &bucket : properties) { for (auto property : bucket) { addProperty(property.key, property.symbol, @@ -405,20 +394,8 @@ PropertyMap::buildPropertyMap(unsigned maxIterations, // concrete types. concretelySimplifyLeftHandSideSubstitutions(); - unsigned addedNewRules = System.getRules().size() - ruleCount; - for (unsigned i = ruleCount, e = System.getRules().size(); i < e; ++i) { - const auto &newRule = System.getRule(i); - if (newRule.getDepth() > maxDepth) - return std::make_pair(CompletionResult::MaxDepth, addedNewRules); - } - // Check invariants of the constructed property map. verify(); - - if (System.getRules().size() > maxIterations) - return std::make_pair(CompletionResult::MaxIterations, addedNewRules); - - return std::make_pair(CompletionResult::Success, addedNewRules); } /// Similar to RewriteSystem::simplifySubstitutions(), but also replaces type diff --git a/lib/AST/RequirementMachine/PropertyMap.h b/lib/AST/RequirementMachine/PropertyMap.h index 37561173b9d23..ea346953cbc8f 100644 --- a/lib/AST/RequirementMachine/PropertyMap.h +++ b/lib/AST/RequirementMachine/PropertyMap.h @@ -206,9 +206,7 @@ class PropertyMap { std::reverse_iterator end) const; PropertyBag *lookUpProperties(const MutableTerm &key) const; - std::pair - buildPropertyMap(unsigned maxIterations, - unsigned maxDepth); + void buildPropertyMap(); void dump(llvm::raw_ostream &out) const; diff --git a/lib/AST/RequirementMachine/RequirementMachine.cpp b/lib/AST/RequirementMachine/RequirementMachine.cpp index 3c1689933fdb9..10c131567284b 100644 --- a/lib/AST/RequirementMachine/RequirementMachine.cpp +++ b/lib/AST/RequirementMachine/RequirementMachine.cpp @@ -25,8 +25,8 @@ RequirementMachine::RequirementMachine(RewriteContext &ctx) : Context(ctx), System(ctx), Map(System) { auto &langOpts = ctx.getASTContext().LangOpts; Dump = langOpts.DumpRequirementMachine; - RequirementMachineStepLimit = langOpts.RequirementMachineStepLimit; - 
RequirementMachineDepthLimit = langOpts.RequirementMachineDepthLimit; + MaxRuleCount = langOpts.RequirementMachineMaxRuleCount; + MaxRuleLength = langOpts.RequirementMachineMaxRuleLength; Stats = ctx.getASTContext().Stats; if (Stats) @@ -41,13 +41,13 @@ static void checkCompletionResult(const RequirementMachine &machine, case CompletionResult::Success: break; - case CompletionResult::MaxIterations: - llvm::errs() << "Rewrite system exceeds maximum completion step count\n"; + case CompletionResult::MaxRuleCount: + llvm::errs() << "Rewrite system exceeded maximum rule count\n"; machine.dump(llvm::errs()); abort(); - case CompletionResult::MaxDepth: - llvm::errs() << "Rewrite system exceeds maximum completion depth\n"; + case CompletionResult::MaxRuleLength: + llvm::errs() << "Rewrite system exceeded rule length limit\n"; machine.dump(llvm::errs()); abort(); } @@ -224,43 +224,56 @@ RequirementMachine::initWithWrittenRequirements( CompletionResult RequirementMachine::computeCompletion(RewriteSystem::ValidityPolicy policy) { while (true) { - // First, run the Knuth-Bendix algorithm to resolve overlapping rules. - auto result = System.computeConfluentCompletion( - RequirementMachineStepLimit, - RequirementMachineDepthLimit); - - if (Stats) { - Stats->getFrontendCounters() - .NumRequirementMachineCompletionSteps += result.second; - } + { + unsigned ruleCount = System.getRules().size(); + + // First, run the Knuth-Bendix algorithm to resolve overlapping rules. + auto result = System.computeConfluentCompletion(MaxRuleCount, MaxRuleLength); - // Check for failure. - if (result.first != CompletionResult::Success) - return result.first; + unsigned rulesAdded = (System.getRules().size() - ruleCount); - // Check invariants. 
- System.verifyRewriteRules(policy); + if (Stats) { + Stats->getFrontendCounters() + .NumRequirementMachineCompletionSteps += rulesAdded; + } - // Build the property map, which also performs concrete term - // unification; if this added any new rules, run the completion - // procedure again. - result = Map.buildPropertyMap( - RequirementMachineStepLimit, - RequirementMachineDepthLimit); + // Check for failure. + if (result != CompletionResult::Success) + return result; - if (Stats) { - Stats->getFrontendCounters() - .NumRequirementMachineUnifiedConcreteTerms += result.second; + // Check invariants. + System.verifyRewriteRules(policy); } - // Check for failure. - if (result.first != CompletionResult::Success) - return result.first; + { + unsigned ruleCount = System.getRules().size(); + + // Build the property map, which also performs concrete term + // unification; if this added any new rules, run the completion + // procedure again. + Map.buildPropertyMap(); + + unsigned rulesAdded = (System.getRules().size() - ruleCount); - // If buildPropertyMap() added new rules, we run another round of - // Knuth-Bendix, and build the property map again. - if (result.second == 0) - break; + if (Stats) { + Stats->getFrontendCounters() + .NumRequirementMachineUnifiedConcreteTerms += rulesAdded; + } + + // Check new rules added by the property map against configured limits. + for (unsigned i = 0; i < rulesAdded; ++i) { + const auto &newRule = System.getRule(ruleCount + i); + if (newRule.getDepth() > MaxRuleLength) + return CompletionResult::MaxRuleLength; + } + + if (System.getRules().size() > MaxRuleCount) + return CompletionResult::MaxRuleCount; + + // If buildPropertyMap() didn't add any new rules, we are done. 
+ if (rulesAdded == 0) + break; + } } if (Dump) { diff --git a/lib/AST/RequirementMachine/RequirementMachine.h b/lib/AST/RequirementMachine/RequirementMachine.h index f400d84905240..e8ce67d8a4387 100644 --- a/lib/AST/RequirementMachine/RequirementMachine.h +++ b/lib/AST/RequirementMachine/RequirementMachine.h @@ -61,8 +61,10 @@ class RequirementMachine final { bool Dump = false; bool Complete = false; - unsigned RequirementMachineStepLimit; - unsigned RequirementMachineDepthLimit; + + /// Parameters to prevent runaway completion and property map construction. + unsigned MaxRuleCount; + unsigned MaxRuleLength; UnifiedStatsReporter *Stats; diff --git a/lib/AST/RequirementMachine/RequirementMachineRequests.cpp b/lib/AST/RequirementMachine/RequirementMachineRequests.cpp index 5411a61412e44..e364c268c0551 100644 --- a/lib/AST/RequirementMachine/RequirementMachineRequests.cpp +++ b/lib/AST/RequirementMachine/RequirementMachineRequests.cpp @@ -266,7 +266,7 @@ RequirementSignatureRequestRQM::evaluate(Evaluator &evaluator, ctx.Diags.diagnose(otherProto->getLoc(), diag::requirement_machine_completion_failed, /*protocol=*/1, - status == CompletionResult::MaxIterations ? 0 : 1); + unsigned(status)); if (otherProto != proto) { ctx.evaluator.cacheOutput( @@ -504,7 +504,7 @@ InferredGenericSignatureRequestRQM::evaluate( ctx.Diags.diagnose(loc, diag::requirement_machine_completion_failed, /*protocol=*/0, - status == CompletionResult::MaxIterations ? 0 : 1); + unsigned(status)); auto result = GenericSignature::get(genericParams, {}); return GenericSignatureWithError(result, /*hadError=*/true); diff --git a/lib/AST/RequirementMachine/RewriteSystem.h b/lib/AST/RequirementMachine/RewriteSystem.h index 3f21ab6d17230..2ccd9136bbf5f 100644 --- a/lib/AST/RequirementMachine/RewriteSystem.h +++ b/lib/AST/RequirementMachine/RewriteSystem.h @@ -192,18 +192,17 @@ class Rule final { } }; -/// Result type for RewriteSystem::computeConfluentCompletion() and -/// PropertyMap::buildPropertyMap(). 
+/// Result type for RequirementMachine::computeCompletion(). enum class CompletionResult { - /// Confluent completion was computed successfully. + /// Completion was successful. Success, - /// Maximum number of iterations reached. - MaxIterations, + /// Maximum number of rules exceeded. + MaxRuleCount, + + /// Maximum rule length exceeded. + MaxRuleLength, - /// Completion produced a rewrite rule whose left hand side has a length - /// exceeding the limit. - MaxDepth }; /// A term rewrite system for working with types in a generic signature. @@ -322,9 +321,8 @@ class RewriteSystem final { /// Pairs of rules which have already been checked for overlap. llvm::DenseSet> CheckedOverlaps; - std::pair - computeConfluentCompletion(unsigned maxIterations, - unsigned maxDepth); + CompletionResult computeConfluentCompletion(unsigned maxRuleCount, + unsigned maxRuleLength); void simplifyLeftHandSides(); diff --git a/lib/Frontend/CompilerInvocation.cpp b/lib/Frontend/CompilerInvocation.cpp index 57752c09f4c0b..8ced5974291a9 100644 --- a/lib/Frontend/CompilerInvocation.cpp +++ b/lib/Frontend/CompilerInvocation.cpp @@ -928,25 +928,25 @@ static bool ParseLangArgs(LangOptions &Opts, ArgList &Args, if (const Arg *A = Args.getLastArg(OPT_debug_requirement_machine)) Opts.DebugRequirementMachine = A->getValue(); - if (const Arg *A = Args.getLastArg(OPT_requirement_machine_step_limit)) { + if (const Arg *A = Args.getLastArg(OPT_requirement_machine_max_rule_count)) { unsigned limit; if (StringRef(A->getValue()).getAsInteger(10, limit)) { Diags.diagnose(SourceLoc(), diag::error_invalid_arg_value, A->getAsString(Args), A->getValue()); HadError = true; } else { - Opts.RequirementMachineStepLimit = limit; + Opts.RequirementMachineMaxRuleCount = limit; } } - if (const Arg *A = Args.getLastArg(OPT_requirement_machine_depth_limit)) { + if (const Arg *A = Args.getLastArg(OPT_requirement_machine_max_rule_length)) { unsigned limit; if (StringRef(A->getValue()).getAsInteger(10, limit)) { 
Diags.diagnose(SourceLoc(), diag::error_invalid_arg_value, A->getAsString(Args), A->getValue()); HadError = true; } else { - Opts.RequirementMachineDepthLimit = limit; + Opts.RequirementMachineMaxRuleLength = limit; } } diff --git a/test/Generics/non_confluent.swift b/test/Generics/non_confluent.swift index 7cdf60026adbf..89def46e8c3f2 100644 --- a/test/Generics/non_confluent.swift +++ b/test/Generics/non_confluent.swift @@ -1,12 +1,12 @@ // RUN: %target-typecheck-verify-swift -requirement-machine-protocol-signatures=on -requirement-machine-inferred-signatures=on -protocol ABA // expected-error {{cannot build rewrite system for protocol; depth limit exceeded}} +protocol ABA // expected-error {{cannot build rewrite system for protocol; rule length limit exceeded}} where A.B == A.B.A { // expected-error *{{is not a member type}} associatedtype A : ABA associatedtype B : ABA } -protocol Undecidable // expected-error {{cannot build rewrite system for protocol; depth limit exceeded}} +protocol Undecidable // expected-error {{cannot build rewrite system for protocol; rule length limit exceeded}} where A.C == C.A, // expected-error *{{is not a member type}} A.D == D.A, // expected-error *{{is not a member type}} B.C == C.B, // expected-error *{{is not a member type}} @@ -30,17 +30,17 @@ protocol P2 { } func foo(_: T) {} -// expected-error@-1 {{cannot build rewrite system for generic signature; depth limit exceeded}} +// expected-error@-1 {{cannot build rewrite system for generic signature; rule length limit exceeded}} extension P1 where Self : P2 {} -// expected-error@-1 {{cannot build rewrite system for generic signature; depth limit exceeded}} +// expected-error@-1 {{cannot build rewrite system for generic signature; rule length limit exceeded}} struct S : P1 { typealias T = S> } protocol P3 { -// expected-error@-1 {{cannot build rewrite system for protocol; depth limit exceeded}} +// expected-error@-1 {{cannot build rewrite system for protocol; rule length limit 
exceeded}} associatedtype T : P1 where T == S // expected-error@-1 {{type 'Self.U' does not conform to protocol 'P1'}} associatedtype U : P1 From 0060592b851970943fe27505ebb117333b8bcac5 Mon Sep 17 00:00:00 2001 From: Slava Pestov Date: Thu, 3 Feb 2022 21:46:20 -0500 Subject: [PATCH 11/25] RequirementMachine: Add concrete nesting depth check Configured with -requirement-machine-max-concrete-nesting= frontend flag. --- include/swift/AST/DiagnosticsSema.def | 2 +- include/swift/Basic/LangOptions.h | 4 +++ include/swift/Option/FrontendOptions.td | 4 +++ .../RequirementMachine/RequirementMachine.cpp | 8 +++++ .../RequirementMachine/RequirementMachine.h | 1 + lib/AST/RequirementMachine/RewriteSystem.cpp | 34 +++++++++++++++++++ lib/AST/RequirementMachine/RewriteSystem.h | 4 +++ lib/Frontend/CompilerInvocation.cpp | 11 ++++++ test/Generics/infinite_concrete_type.swift | 20 +++++++++++ 9 files changed, 87 insertions(+), 1 deletion(-) create mode 100644 test/Generics/infinite_concrete_type.swift diff --git a/include/swift/AST/DiagnosticsSema.def b/include/swift/AST/DiagnosticsSema.def index bde55fef7c9aa..fee7fd58b9307 100644 --- a/include/swift/AST/DiagnosticsSema.def +++ b/include/swift/AST/DiagnosticsSema.def @@ -2559,7 +2559,7 @@ WARNING(associated_type_override_typealias,none, ERROR(requirement_machine_completion_failed,none, "cannot build rewrite system for %select{generic signature|protocol}0; " - "%select{%error|rule count|rule length}1 limit exceeded", + "%select{%error|rule count|rule length|concrete nesting}1 limit exceeded", (unsigned, unsigned)) ERROR(associated_type_objc,none, diff --git a/include/swift/Basic/LangOptions.h b/include/swift/Basic/LangOptions.h index 96a458fbc1261..bb3300b8bd64b 100644 --- a/include/swift/Basic/LangOptions.h +++ b/include/swift/Basic/LangOptions.h @@ -518,6 +518,10 @@ namespace swift { /// algorithm. 
unsigned RequirementMachineMaxRuleLength = 12; + /// Maximum concrete type nesting depth for requirement machine property map + /// algorithm. + unsigned RequirementMachineMaxConcreteNesting = 30; + /// Enable the new experimental protocol requirement signature minimization /// algorithm. RequirementMachineMode RequirementMachineProtocolSignatures = diff --git a/include/swift/Option/FrontendOptions.td b/include/swift/Option/FrontendOptions.td index 907a35be0a98b..66aeff71de455 100644 --- a/include/swift/Option/FrontendOptions.td +++ b/include/swift/Option/FrontendOptions.td @@ -351,6 +351,10 @@ def requirement_machine_max_rule_length : Joined<["-"], "requirement-machine-max Flags<[FrontendOption, HelpHidden, DoesNotAffectIncrementalBuild]>, HelpText<"Set the maximum rule length before giving up">; +def requirement_machine_max_concrete_nesting : Joined<["-"], "requirement-machine-max-concrete-nesting=">, + Flags<[FrontendOption, HelpHidden, DoesNotAffectIncrementalBuild]>, + HelpText<"Set the maximum concrete type nesting depth before giving up">; + def disable_requirement_machine_merged_associated_types : Flag<["-"], "disable-requirement-machine-merged-associated-types">, Flags<[FrontendOption, HelpHidden, DoesNotAffectIncrementalBuild]>, HelpText<"Disable merged associated types">; diff --git a/lib/AST/RequirementMachine/RequirementMachine.cpp b/lib/AST/RequirementMachine/RequirementMachine.cpp index 10c131567284b..e7edab1d05fde 100644 --- a/lib/AST/RequirementMachine/RequirementMachine.cpp +++ b/lib/AST/RequirementMachine/RequirementMachine.cpp @@ -27,6 +27,7 @@ RequirementMachine::RequirementMachine(RewriteContext &ctx) Dump = langOpts.DumpRequirementMachine; MaxRuleCount = langOpts.RequirementMachineMaxRuleCount; MaxRuleLength = langOpts.RequirementMachineMaxRuleLength; + MaxConcreteNesting = langOpts.RequirementMachineMaxConcreteNesting; Stats = ctx.getASTContext().Stats; if (Stats) @@ -50,6 +51,11 @@ static void checkCompletionResult(const RequirementMachine 
&machine, llvm::errs() << "Rewrite system exceeded rule length limit\n"; machine.dump(llvm::errs()); abort(); + + case CompletionResult::MaxConcreteNesting: + llvm::errs() << "Rewrite system exceeded concrete type nesting depth limit\n"; + machine.dump(llvm::errs()); + abort(); } } @@ -265,6 +271,8 @@ RequirementMachine::computeCompletion(RewriteSystem::ValidityPolicy policy) { const auto &newRule = System.getRule(ruleCount + i); if (newRule.getDepth() > MaxRuleLength) return CompletionResult::MaxRuleLength; + if (newRule.getNesting() > MaxConcreteNesting) + return CompletionResult::MaxConcreteNesting; } if (System.getRules().size() > MaxRuleCount) diff --git a/lib/AST/RequirementMachine/RequirementMachine.h b/lib/AST/RequirementMachine/RequirementMachine.h index e8ce67d8a4387..9dcc1a03c97dc 100644 --- a/lib/AST/RequirementMachine/RequirementMachine.h +++ b/lib/AST/RequirementMachine/RequirementMachine.h @@ -65,6 +65,7 @@ class RequirementMachine final { /// Parameters to prevent runaway completion and property map construction. unsigned MaxRuleCount; unsigned MaxRuleLength; + unsigned MaxConcreteNesting; UnifiedStatsReporter *Stats; diff --git a/lib/AST/RequirementMachine/RewriteSystem.cpp b/lib/AST/RequirementMachine/RewriteSystem.cpp index dfd6036ed95b8..d1919a4ecc22a 100644 --- a/lib/AST/RequirementMachine/RewriteSystem.cpp +++ b/lib/AST/RequirementMachine/RewriteSystem.cpp @@ -12,6 +12,7 @@ #include "swift/AST/Decl.h" #include "swift/AST/Types.h" +#include "swift/AST/TypeWalker.h" #include "llvm/ADT/FoldingSet.h" #include "llvm/Support/raw_ostream.h" #include @@ -132,6 +133,39 @@ unsigned Rule::getDepth() const { return result; } +/// Returns the nesting depth of the concrete symbol at the end of the +/// left hand side, or 0 if there isn't one. 
+unsigned Rule::getNesting() const { + if (LHS.back().hasSubstitutions()) { + auto type = LHS.back().getConcreteType(); + + struct Walker : TypeWalker { + unsigned Nesting = 0; + unsigned MaxNesting = 0; + + Action walkToTypePre(Type ty) override { + ++Nesting; + MaxNesting = std::max(Nesting, MaxNesting); + + return Action::Continue; + } + + Action walkToTypePost(Type ty) override { + --Nesting; + + return Action::Continue; + } + }; + + Walker walker; + type.walk(walker); + + return walker.MaxNesting; + } + + return 0; +} + /// Linear order on rules; compares LHS followed by RHS. int Rule::compare(const Rule &other, RewriteContext &ctx) const { int compare = LHS.compare(other.LHS, ctx); diff --git a/lib/AST/RequirementMachine/RewriteSystem.h b/lib/AST/RequirementMachine/RewriteSystem.h index 2ccd9136bbf5f..feb980f5df6e7 100644 --- a/lib/AST/RequirementMachine/RewriteSystem.h +++ b/lib/AST/RequirementMachine/RewriteSystem.h @@ -181,6 +181,8 @@ class Rule final { unsigned getDepth() const; + unsigned getNesting() const; + int compare(const Rule &other, RewriteContext &ctx) const; void dump(llvm::raw_ostream &out) const; @@ -203,6 +205,8 @@ enum class CompletionResult { /// Maximum rule length exceeded. MaxRuleLength, + /// Maximum concrete type nesting depth exceeded. + MaxConcreteNesting }; /// A term rewrite system for working with types in a generic signature. 
diff --git a/lib/Frontend/CompilerInvocation.cpp b/lib/Frontend/CompilerInvocation.cpp index 8ced5974291a9..5bd4e19952209 100644 --- a/lib/Frontend/CompilerInvocation.cpp +++ b/lib/Frontend/CompilerInvocation.cpp @@ -950,6 +950,17 @@ static bool ParseLangArgs(LangOptions &Opts, ArgList &Args, } } + if (const Arg *A = Args.getLastArg(OPT_requirement_machine_max_concrete_nesting)) { + unsigned limit; + if (StringRef(A->getValue()).getAsInteger(10, limit)) { + Diags.diagnose(SourceLoc(), diag::error_invalid_arg_value, + A->getAsString(Args), A->getValue()); + HadError = true; + } else { + Opts.RequirementMachineMaxConcreteNesting = limit; + } + } + return HadError || UnsupportedOS || UnsupportedArch; } diff --git a/test/Generics/infinite_concrete_type.swift b/test/Generics/infinite_concrete_type.swift new file mode 100644 index 0000000000000..95f3a2ee9bb14 --- /dev/null +++ b/test/Generics/infinite_concrete_type.swift @@ -0,0 +1,20 @@ +// RUN: %target-typecheck-verify-swift -requirement-machine-protocol-signatures=on -requirement-machine-inferred-signatures=on + +class G {} + +protocol P1 { // expected-error {{cannot build rewrite system for protocol; concrete nesting limit exceeded}} + associatedtype A where A == G + associatedtype B where B == G +} + +// The GenericSignatureBuilder rejected this protocol, but there's no real +// reason to do that. 
+protocol P2 { + associatedtype A where A : G + associatedtype B where B : G +} + +func useP2(_: T) { + _ = T.A.self + _ = T.B.self +} From b0dd114fdd65932f30c26f883b571fe66ec477d4 Mon Sep 17 00:00:00 2001 From: Slava Pestov Date: Wed, 2 Feb 2022 23:25:37 -0500 Subject: [PATCH 12/25] RequirementMachine: Add a FIXME comment --- lib/AST/RequirementMachine/PropertyUnification.cpp | 3 +++ 1 file changed, 3 insertions(+) diff --git a/lib/AST/RequirementMachine/PropertyUnification.cpp b/lib/AST/RequirementMachine/PropertyUnification.cpp index ba12cf386cc91..f2224b5c3f042 100644 --- a/lib/AST/RequirementMachine/PropertyUnification.cpp +++ b/lib/AST/RequirementMachine/PropertyUnification.cpp @@ -518,6 +518,9 @@ void PropertyMap::addConcreteTypeProperty( assert(simplified); (void) simplified; + // FIXME: This is unsound! While 'key' was canonical at the time we + // started property map construction, we might have added other rules + // since then that made it non-canonical. assert(path.size() == 1); assert(path.begin()->Kind == RewriteStep::Rule); From 37be2d5dd700f87ce845a486c3fadfdaeb7a3455 Mon Sep 17 00:00:00 2001 From: Slava Pestov Date: Thu, 3 Feb 2022 23:16:45 -0500 Subject: [PATCH 13/25] RequirementMachine: Emit a diagnostic note with the offending rewrite rule if completion failed This surfaces an implementation detail, but it might be better than nothing. 
--- include/swift/AST/DiagnosticsSema.def | 2 + lib/AST/RequirementMachine/KnuthBendix.cpp | 15 ++++-- .../RequirementMachine/RequirementMachine.cpp | 46 +++++++++++++------ .../RequirementMachine/RequirementMachine.h | 11 +++-- .../RequirementMachineRequests.cpp | 18 ++++++-- lib/AST/RequirementMachine/RewriteSystem.h | 5 +- test/Generics/infinite_concrete_type.swift | 1 + test/Generics/non_confluent.swift | 6 +++ 8 files changed, 77 insertions(+), 27 deletions(-) diff --git a/include/swift/AST/DiagnosticsSema.def b/include/swift/AST/DiagnosticsSema.def index fee7fd58b9307..f024aee8fc4f9 100644 --- a/include/swift/AST/DiagnosticsSema.def +++ b/include/swift/AST/DiagnosticsSema.def @@ -2561,6 +2561,8 @@ ERROR(requirement_machine_completion_failed,none, "cannot build rewrite system for %select{generic signature|protocol}0; " "%select{%error|rule count|rule length|concrete nesting}1 limit exceeded", (unsigned, unsigned)) +NOTE(requirement_machine_completion_rule,none, + "failed rewrite rule is %0", (StringRef)) ERROR(associated_type_objc,none, "associated type %0 cannot be declared inside '@objc' protocol %1", diff --git a/lib/AST/RequirementMachine/KnuthBendix.cpp b/lib/AST/RequirementMachine/KnuthBendix.cpp index 147e975c3bfea..59127821e907c 100644 --- a/lib/AST/RequirementMachine/KnuthBendix.cpp +++ b/lib/AST/RequirementMachine/KnuthBendix.cpp @@ -489,14 +489,19 @@ RewriteSystem::computeCriticalPair(ArrayRef::const_iterator from, /// Computes the confluent completion using the Knuth-Bendix algorithm and /// returns a status code. /// +/// The first element of the pair is a status. +/// /// The status is CompletionResult::MaxRuleCount if we add more than /// \p maxRuleCount rules. /// /// The status is CompletionResult::MaxRuleLength if we produce a rewrite rule /// whose left hand side has a length exceeding \p maxRuleLength. /// -/// Otherwise, the status is CompletionResult::Success. 
-CompletionResult +/// In the above two cases, the second element of the pair is a rule ID. +/// +/// Otherwise, the status is CompletionResult::Success and the second element +/// is zero. +std::pair RewriteSystem::computeConfluentCompletion(unsigned maxRuleCount, unsigned maxRuleLength) { assert(Initialized); @@ -603,14 +608,14 @@ RewriteSystem::computeConfluentCompletion(unsigned maxRuleCount, for (const auto &pair : resolvedCriticalPairs) { // Check if we've already done too much work. if (Rules.size() > maxRuleCount) - return CompletionResult::MaxRuleCount; + return std::make_pair(CompletionResult::MaxRuleCount, Rules.size() - 1); if (!addRule(pair.LHS, pair.RHS, &pair.Path)) continue; // Check if the new rule is too long. if (Rules.back().getDepth() > maxRuleLength) - return CompletionResult::MaxRuleLength; + return std::make_pair(CompletionResult::MaxRuleLength, Rules.size() - 1); again = true; } @@ -635,5 +640,5 @@ RewriteSystem::computeConfluentCompletion(unsigned maxRuleCount, assert(MergedAssociatedTypes.empty() && "Should have processed all merge candidates"); - return CompletionResult::Success; + return std::make_pair(CompletionResult::Success, 0); } diff --git a/lib/AST/RequirementMachine/RequirementMachine.cpp b/lib/AST/RequirementMachine/RequirementMachine.cpp index e7edab1d05fde..e8d6db609706b 100644 --- a/lib/AST/RequirementMachine/RequirementMachine.cpp +++ b/lib/AST/RequirementMachine/RequirementMachine.cpp @@ -92,7 +92,7 @@ void RequirementMachine::initWithGenericSignature(CanGenericSignature sig) { std::move(builder.RequirementRules)); auto result = computeCompletion(RewriteSystem::DisallowInvalidRequirements); - checkCompletionResult(*this, result); + checkCompletionResult(*this, result.first); if (Dump) { llvm::dbgs() << "}\n"; @@ -109,7 +109,7 @@ void RequirementMachine::initWithGenericSignature(CanGenericSignature sig) { /// Used by RequirementSignatureRequest. 
/// /// Returns failure if completion fails within the configured number of steps. -CompletionResult +std::pair RequirementMachine::initWithProtocols(ArrayRef protos) { FrontendStatsTracer tracer(Stats, "build-rewrite-system"); @@ -173,7 +173,7 @@ void RequirementMachine::initWithAbstractRequirements( std::move(builder.RequirementRules)); auto result = computeCompletion(RewriteSystem::AllowInvalidRequirements); - checkCompletionResult(*this, result); + checkCompletionResult(*this, result.first); if (Dump) { llvm::dbgs() << "}\n"; @@ -189,7 +189,7 @@ void RequirementMachine::initWithAbstractRequirements( /// Used by InferredGenericSignatureRequest. /// /// Returns failure if completion fails within the configured number of steps. -CompletionResult +std::pair RequirementMachine::initWithWrittenRequirements( ArrayRef genericParams, ArrayRef requirements) { @@ -227,7 +227,11 @@ RequirementMachine::initWithWrittenRequirements( /// Attempt to obtain a confluent rewrite system by iterating the Knuth-Bendix /// completion procedure together with property map construction until fixed /// point. -CompletionResult +/// +/// Returns a pair where the first element is the status. If the status is not +/// CompletionResult::Success, the second element of the pair is the rule ID +/// which triggered failure. +std::pair RequirementMachine::computeCompletion(RewriteSystem::ValidityPolicy policy) { while (true) { { @@ -244,7 +248,7 @@ RequirementMachine::computeCompletion(RewriteSystem::ValidityPolicy policy) { } // Check for failure. - if (result != CompletionResult::Success) + if (result.first != CompletionResult::Success) return result; // Check invariants. @@ -269,14 +273,20 @@ RequirementMachine::computeCompletion(RewriteSystem::ValidityPolicy policy) { // Check new rules added by the property map against configured limits. 
for (unsigned i = 0; i < rulesAdded; ++i) { const auto &newRule = System.getRule(ruleCount + i); - if (newRule.getDepth() > MaxRuleLength) - return CompletionResult::MaxRuleLength; - if (newRule.getNesting() > MaxConcreteNesting) - return CompletionResult::MaxConcreteNesting; + if (newRule.getDepth() > MaxRuleLength) { + return std::make_pair(CompletionResult::MaxRuleLength, + ruleCount + i); + } + if (newRule.getNesting() > MaxConcreteNesting) { + return std::make_pair(CompletionResult::MaxConcreteNesting, + ruleCount + i); + } } - if (System.getRules().size() > MaxRuleCount) - return CompletionResult::MaxRuleCount; + if (System.getRules().size() > MaxRuleCount) { + return std::make_pair(CompletionResult::MaxRuleCount, + System.getRules().size() - 1); + } // If buildPropertyMap() didn't add any new rules, we are done. if (rulesAdded == 0) @@ -291,7 +301,17 @@ RequirementMachine::computeCompletion(RewriteSystem::ValidityPolicy policy) { assert(!Complete); Complete = true; - return CompletionResult::Success; + return std::make_pair(CompletionResult::Success, 0); +} + +std::string RequirementMachine::getRuleAsStringForDiagnostics( + unsigned ruleID) const { + const auto &rule = System.getRule(ruleID); + + std::string result; + llvm::raw_string_ostream out(result); + out << rule; + return out.str(); } bool RequirementMachine::isComplete() const { diff --git a/lib/AST/RequirementMachine/RequirementMachine.h b/lib/AST/RequirementMachine/RequirementMachine.h index 9dcc1a03c97dc..a887520066926 100644 --- a/lib/AST/RequirementMachine/RequirementMachine.h +++ b/lib/AST/RequirementMachine/RequirementMachine.h @@ -88,17 +88,20 @@ class RequirementMachine final { RequirementMachine &operator=(RequirementMachine &&) = delete; void initWithGenericSignature(CanGenericSignature sig); - CompletionResult initWithProtocols(ArrayRef protos); + std::pair + initWithProtocols(ArrayRef protos); void initWithAbstractRequirements( ArrayRef genericParams, ArrayRef requirements); - 
CompletionResult initWithWrittenRequirements( + std::pair + initWithWrittenRequirements( ArrayRef genericParams, ArrayRef requirements); bool isComplete() const; - CompletionResult computeCompletion(RewriteSystem::ValidityPolicy policy); + std::pair + computeCompletion(RewriteSystem::ValidityPolicy policy); MutableTerm getLongestValidPrefix(const MutableTerm &term) const; @@ -142,6 +145,8 @@ class RequirementMachine final { std::vector computeMinimalGenericSignatureRequirements(); + std::string getRuleAsStringForDiagnostics(unsigned ruleID) const; + bool hadError() const; void verify(const MutableTerm &term) const; diff --git a/lib/AST/RequirementMachine/RequirementMachineRequests.cpp b/lib/AST/RequirementMachine/RequirementMachineRequests.cpp index e364c268c0551..e8dfa0482c716 100644 --- a/lib/AST/RequirementMachine/RequirementMachineRequests.cpp +++ b/lib/AST/RequirementMachine/RequirementMachineRequests.cpp @@ -259,14 +259,19 @@ RequirementSignatureRequestRQM::evaluate(Evaluator &evaluator, ctx.getRewriteContext())); auto status = machine->initWithProtocols(component); - if (status != CompletionResult::Success) { + if (status.first != CompletionResult::Success) { // All we can do at this point is diagnose and give each protocol an empty // requirement signature. 
for (const auto *otherProto : component) { ctx.Diags.diagnose(otherProto->getLoc(), diag::requirement_machine_completion_failed, /*protocol=*/1, - unsigned(status)); + unsigned(status.first)); + + auto rule = machine->getRuleAsStringForDiagnostics(status.second); + ctx.Diags.diagnose(otherProto->getLoc(), + diag::requirement_machine_completion_rule, + rule); if (otherProto != proto) { ctx.evaluator.cacheOutput( @@ -500,11 +505,16 @@ InferredGenericSignatureRequestRQM::evaluate( ctx.getRewriteContext())); auto status = machine->initWithWrittenRequirements(genericParams, requirements); - if (status != CompletionResult::Success) { + if (status.first != CompletionResult::Success) { ctx.Diags.diagnose(loc, diag::requirement_machine_completion_failed, /*protocol=*/0, - unsigned(status)); + unsigned(status.first)); + + auto rule = machine->getRuleAsStringForDiagnostics(status.second); + ctx.Diags.diagnose(loc, + diag::requirement_machine_completion_rule, + rule); auto result = GenericSignature::get(genericParams, {}); return GenericSignatureWithError(result, /*hadError=*/true); diff --git a/lib/AST/RequirementMachine/RewriteSystem.h b/lib/AST/RequirementMachine/RewriteSystem.h index feb980f5df6e7..a20b9da257fa8 100644 --- a/lib/AST/RequirementMachine/RewriteSystem.h +++ b/lib/AST/RequirementMachine/RewriteSystem.h @@ -325,8 +325,9 @@ class RewriteSystem final { /// Pairs of rules which have already been checked for overlap. 
llvm::DenseSet> CheckedOverlaps; - CompletionResult computeConfluentCompletion(unsigned maxRuleCount, - unsigned maxRuleLength); + std::pair + computeConfluentCompletion(unsigned maxRuleCount, + unsigned maxRuleLength); void simplifyLeftHandSides(); diff --git a/test/Generics/infinite_concrete_type.swift b/test/Generics/infinite_concrete_type.swift index 95f3a2ee9bb14..4f40d7bb731e2 100644 --- a/test/Generics/infinite_concrete_type.swift +++ b/test/Generics/infinite_concrete_type.swift @@ -3,6 +3,7 @@ class G {} protocol P1 { // expected-error {{cannot build rewrite system for protocol; concrete nesting limit exceeded}} +// expected-note@-1 {{failed rewrite rule is [P1:A].[concrete: G>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> with <[P1:A]>] => [P1:A]}} associatedtype A where A == G associatedtype B where B == G } diff --git a/test/Generics/non_confluent.swift b/test/Generics/non_confluent.swift index 89def46e8c3f2..8bc2476bacd4f 100644 --- a/test/Generics/non_confluent.swift +++ b/test/Generics/non_confluent.swift @@ -1,12 +1,14 @@ // RUN: %target-typecheck-verify-swift -requirement-machine-protocol-signatures=on -requirement-machine-inferred-signatures=on protocol ABA // expected-error {{cannot build rewrite system for protocol; rule length limit exceeded}} +// expected-note@-1 {{failed rewrite rule is [ABA:A].[ABA:B].[ABA:B].[ABA:B].[ABA:B].[ABA:B].[ABA:B].[ABA:B].[ABA:B].[ABA:B].[ABA:B].[ABA:B].[ABA:A] => [ABA:A].[ABA:B].[ABA:B].[ABA:B].[ABA:B].[ABA:B].[ABA:B].[ABA:B].[ABA:B].[ABA:B].[ABA:B].[ABA:B]}} where A.B == A.B.A { // expected-error *{{is not a member type}} associatedtype A : ABA associatedtype B : ABA } protocol Undecidable // expected-error {{cannot build rewrite system for protocol; rule length limit exceeded}} +// expected-note@-1 {{failed rewrite rule is [Undecidable:A].[Undecidable:B].[Undecidable:D].[Undecidable:C].[Undecidable:C].[Undecidable:C].[Undecidable:E].[Undecidable:B].[Undecidable:A].[Undecidable:A].[Undecidable:E].[Undecidable:C].[Undecidable:E] 
=> [Undecidable:A].[Undecidable:B].[Undecidable:D].[Undecidable:C].[Undecidable:C].[Undecidable:C].[Undecidable:E].[Undecidable:B].[Undecidable:A].[Undecidable:A].[Undecidable:E].[Undecidable:C]}} where A.C == C.A, // expected-error *{{is not a member type}} A.D == D.A, // expected-error *{{is not a member type}} B.C == C.B, // expected-error *{{is not a member type}} @@ -31,9 +33,11 @@ protocol P2 { func foo(_: T) {} // expected-error@-1 {{cannot build rewrite system for generic signature; rule length limit exceeded}} +// expected-note@-2 {{failed rewrite rule is τ_0_0.[P1:T].[P1:T].[P1:T].[P1:T].[P1:T].[P1:T].[P1:T].[P1:T].[P1:T].[P1:T].[P1:T].[P2] => τ_0_0.[P1:T].[P1:T].[P1:T].[P1:T].[P1:T].[P1:T].[P1:T].[P1:T].[P1:T].[P1:T].[P1:T]}} extension P1 where Self : P2 {} // expected-error@-1 {{cannot build rewrite system for generic signature; rule length limit exceeded}} +// expected-note@-2 {{failed rewrite rule is τ_0_0.[P1:T].[P1:T].[P1:T].[P1:T].[P1:T].[P1:T].[P1:T].[P1:T].[P1:T].[P1:T].[P1:T].[P2] => τ_0_0.[P1:T].[P1:T].[P1:T].[P1:T].[P1:T].[P1:T].[P1:T].[P1:T].[P1:T].[P1:T].[P1:T]}} struct S : P1 { typealias T = S> @@ -41,6 +45,8 @@ struct S : P1 { protocol P3 { // expected-error@-1 {{cannot build rewrite system for protocol; rule length limit exceeded}} +// expected-note@-2 {{failed rewrite rule is [P3:T].[P1:T].[P1:T].[P1:T].[P1:T].[P1:T].[P1:T].[P1:T].[P1:T].[P1:T].[P1:T].[P1:T].[concrete: S>>>>>>>>>>> with <[P3:U]>] => [P3:T].[P1:T].[P1:T].[P1:T].[P1:T].[P1:T].[P1:T].[P1:T].[P1:T].[P1:T].[P1:T].[P1:T]}} + associatedtype T : P1 where T == S // expected-error@-1 {{type 'Self.U' does not conform to protocol 'P1'}} associatedtype U : P1 From 037dc98845aa1c5e9a64e92b03b6d4fdae036c52 Mon Sep 17 00:00:00 2001 From: Slava Pestov Date: Fri, 4 Feb 2022 00:04:27 -0500 Subject: [PATCH 14/25] RequirementMachine: Generalize compare() methods to return None instead of asserting on incomparable symbols --- .../RequirementMachine/HomotopyReduction.cpp | 5 +--- 
lib/AST/RequirementMachine/KnuthBendix.cpp | 6 ++--- .../MinimalConformances.cpp | 2 +- lib/AST/RequirementMachine/RewriteSystem.cpp | 14 +++++------ lib/AST/RequirementMachine/RewriteSystem.h | 2 +- lib/AST/RequirementMachine/Symbol.cpp | 16 ++++++------- lib/AST/RequirementMachine/Symbol.h | 2 +- lib/AST/RequirementMachine/Term.cpp | 23 +++++++++++-------- lib/AST/RequirementMachine/Term.h | 4 ++-- lib/AST/RequirementMachine/TypeDifference.cpp | 10 ++++---- 10 files changed, 42 insertions(+), 42 deletions(-) diff --git a/lib/AST/RequirementMachine/HomotopyReduction.cpp b/lib/AST/RequirementMachine/HomotopyReduction.cpp index c8002d7e3ed90..e9fca5522c9c9 100644 --- a/lib/AST/RequirementMachine/HomotopyReduction.cpp +++ b/lib/AST/RequirementMachine/HomotopyReduction.cpp @@ -402,11 +402,8 @@ findRuleToDelete(llvm::function_ref isRedundantRuleFn, const auto &otherRule = getRule(found->second); - // If the new rule is conflicting, don't compare the rules at all - // and prefer to delete the new rule. Otherwise, prefer to delete - // the less canonical of the two rules. 
if (rule.isConflicting() || - rule.compare(otherRule, Context) > 0) { + *rule.compare(otherRule, Context) > 0) { found = pair; } } diff --git a/lib/AST/RequirementMachine/KnuthBendix.cpp b/lib/AST/RequirementMachine/KnuthBendix.cpp index 59127821e907c..1ec2a4b06ec8f 100644 --- a/lib/AST/RequirementMachine/KnuthBendix.cpp +++ b/lib/AST/RequirementMachine/KnuthBendix.cpp @@ -90,7 +90,7 @@ Symbol RewriteContext::mergeAssociatedTypes(Symbol lhs, Symbol rhs) { assert(lhs.getKind() == Symbol::Kind::AssociatedType); assert(rhs.getKind() == Symbol::Kind::AssociatedType); assert(lhs.getName() == rhs.getName()); - assert(lhs.compare(rhs, *this) > 0); + assert(*lhs.compare(rhs, *this) > 0); auto protos = lhs.getProtocols(); auto otherProtos = rhs.getProtocols(); @@ -320,8 +320,8 @@ void RewriteSystem::checkMergedAssociatedType(Term lhs, Term rhs) { // We must have mergedSymbol <= rhs < lhs, therefore mergedSymbol != lhs. assert(lhs.back() != mergedSymbol && "Left hand side should not already end with merged symbol?"); - assert(mergedSymbol.compare(rhs.back(), Context) <= 0); - assert(rhs.back().compare(lhs.back(), Context) < 0); + assert(*mergedSymbol.compare(rhs.back(), Context) <= 0); + assert(*rhs.back().compare(lhs.back(), Context) < 0); // If the merge didn't actually produce a new symbol, there is nothing else // to do. 
diff --git a/lib/AST/RequirementMachine/MinimalConformances.cpp b/lib/AST/RequirementMachine/MinimalConformances.cpp index c69c6b04cfbe4..868084726f425 100644 --- a/lib/AST/RequirementMachine/MinimalConformances.cpp +++ b/lib/AST/RequirementMachine/MinimalConformances.cpp @@ -435,7 +435,7 @@ void MinimalConformances::collectConformanceRules() { if (lhsRule.isExplicit() != rhsRule.isExplicit()) return !lhsRule.isExplicit(); - return lhsRule.getLHS().compare(rhsRule.getLHS(), Context) > 0; + return *lhsRule.getLHS().compare(rhsRule.getLHS(), Context) > 0; }); Context.ConformanceRulesHistogram.add(ConformanceRules.size()); diff --git a/lib/AST/RequirementMachine/RewriteSystem.cpp b/lib/AST/RequirementMachine/RewriteSystem.cpp index d1919a4ecc22a..745734f3a3c81 100644 --- a/lib/AST/RequirementMachine/RewriteSystem.cpp +++ b/lib/AST/RequirementMachine/RewriteSystem.cpp @@ -167,9 +167,9 @@ unsigned Rule::getNesting() const { } /// Linear order on rules; compares LHS followed by RHS. -int Rule::compare(const Rule &other, RewriteContext &ctx) const { - int compare = LHS.compare(other.LHS, ctx); - if (compare != 0) +Optional<int> Rule::compare(const Rule &other, RewriteContext &ctx) const { + Optional<int> compare = LHS.compare(other.LHS, ctx); + if (!compare.hasValue() || *compare != 0) return compare; return RHS.compare(other.RHS, ctx); @@ -415,8 +415,8 @@ bool RewriteSystem::addRule(MutableTerm lhs, MutableTerm rhs, // If the left hand side and right hand side are already equivalent, we're // done. - int result = lhs.compare(rhs, Context); - if (result == 0) { + Optional<int> result = lhs.compare(rhs, Context); + if (*result == 0) { // If this rule is a consequence of existing rules, add a homotopy // generator. if (path) { @@ -436,12 +436,12 @@ bool RewriteSystem::addRule(MutableTerm lhs, MutableTerm rhs, // Orient the two terms so that the left hand side is greater than the // right hand side. 
- if (result < 0) { + if (*result < 0) { std::swap(lhs, rhs); loop.invert(); } - assert(lhs.compare(rhs, Context) > 0); + assert(*lhs.compare(rhs, Context) > 0); if (Debug.contains(DebugFlags::Add)) { llvm::dbgs() << "## Simplified and oriented rule " << lhs << " => " << rhs << "\n\n"; diff --git a/lib/AST/RequirementMachine/RewriteSystem.h b/lib/AST/RequirementMachine/RewriteSystem.h index a20b9da257fa8..f4047ec4542cd 100644 --- a/lib/AST/RequirementMachine/RewriteSystem.h +++ b/lib/AST/RequirementMachine/RewriteSystem.h @@ -183,7 +183,7 @@ class Rule final { unsigned getNesting() const; - int compare(const Rule &other, RewriteContext &ctx) const; + Optional<int> compare(const Rule &other, RewriteContext &ctx) const; void dump(llvm::raw_ostream &out) const; diff --git a/lib/AST/RequirementMachine/Symbol.cpp b/lib/AST/RequirementMachine/Symbol.cpp index ee7ddca0e178d..2be68e2561663 100644 --- a/lib/AST/RequirementMachine/Symbol.cpp +++ b/lib/AST/RequirementMachine/Symbol.cpp @@ -498,7 +498,8 @@ ArrayRef<const ProtocolDecl *> Symbol::getRootProtocols() const { llvm_unreachable("Bad root symbol"); } -/// Linear order on symbols. +/// Linear order on symbols, returning -1, 0, 1 or None if the symbols are +/// incomparable. /// /// First, we order different kinds as follows, from smallest to largest: /// @@ -548,8 +549,8 @@ ArrayRef<const ProtocolDecl *> Symbol::getRootProtocols() const { /// * For concrete conformance symbols with distinct protocols, we compare /// the protocols. /// -/// All other symbol kinds are incomparable. -int Symbol::compare(Symbol other, RewriteContext &ctx) const { +/// All other symbol kinds are incomparable, in which case we return None. +Optional<int> Symbol::compare(Symbol other, RewriteContext &ctx) const { // Exit early if the symbols are equal. 
if (Ptr == other.Ptr) return 0; @@ -638,8 +639,8 @@ int Symbol::compare(Symbol other, RewriteContext &ctx) const { auto term = getSubstitutions()[i]; auto otherTerm = other.getSubstitutions()[i]; - result = term.compare(otherTerm, ctx); - if (result != 0) + Optional<int> result = term.compare(otherTerm, ctx); + if (!result.hasValue() || *result != 0) return result; } @@ -647,10 +648,7 @@ int Symbol::compare(Symbol other, RewriteContext &ctx) const { } // We don't support comparing arbitrary concrete types. - llvm::errs() << "Cannot compare concrete types yet\n"; - llvm::errs() << "LHS: " << *this << "\n"; - llvm::errs() << "RHS: " << other << "\n"; - abort(); + return None; } } diff --git a/lib/AST/RequirementMachine/Symbol.h b/lib/AST/RequirementMachine/Symbol.h index 32cfeff0e8352..99f4aab0bcbd7 100644 --- a/lib/AST/RequirementMachine/Symbol.h +++ b/lib/AST/RequirementMachine/Symbol.h @@ -223,7 +223,7 @@ class Symbol final { ArrayRef<const ProtocolDecl *> getRootProtocols() const; - int compare(Symbol other, RewriteContext &ctx) const; + Optional<int> compare(Symbol other, RewriteContext &ctx) const; Symbol withConcreteSubstitutions( ArrayRef<Term> substitutions, diff --git a/lib/AST/RequirementMachine/Term.cpp b/lib/AST/RequirementMachine/Term.cpp index a24fd426546a5..cd0eb02ef496d 100644 --- a/lib/AST/RequirementMachine/Term.cpp +++ b/lib/AST/RequirementMachine/Term.cpp @@ -128,9 +128,10 @@ bool Term::containsUnresolvedSymbols() const { /// /// This is used to implement Term::compare() and MutableTerm::compare() /// below. 
-static int shortlexCompare(const Symbol *lhsBegin, const Symbol *lhsEnd, - const Symbol *rhsBegin, const Symbol *rhsEnd, - RewriteContext &ctx) { +static Optional<int> +shortlexCompare(const Symbol *lhsBegin, const Symbol *lhsEnd, + const Symbol *rhsBegin, const Symbol *rhsEnd, + RewriteContext &ctx) { unsigned lhsSize = (lhsEnd - lhsBegin); unsigned rhsSize = (rhsEnd - rhsBegin); if (lhsSize != rhsSize) @@ -143,8 +144,8 @@ static int shortlexCompare(const Symbol *lhsBegin, const Symbol *lhsEnd, ++lhsBegin; ++rhsBegin; - int result = lhs.compare(rhs, ctx); - if (result != 0) { + Optional<int> result = lhs.compare(rhs, ctx); + if (!result.hasValue() || *result != 0) { assert(lhs != rhs); return result; } @@ -155,13 +156,17 @@ static int shortlexCompare(const Symbol *lhsBegin, const Symbol *lhsEnd, return 0; } -/// Shortlex order on terms. -int Term::compare(Term other, RewriteContext &ctx) const { +/// Shortlex order on terms. Returns None if the terms are identical except +/// for an incomparable superclass or concrete type symbol at the end. +Optional<int> +Term::compare(Term other, RewriteContext &ctx) const { return shortlexCompare(begin(), end(), other.begin(), other.end(), ctx); } -/// Shortlex order on mutable terms. -int MutableTerm::compare(const MutableTerm &other, RewriteContext &ctx) const { +/// Shortlex order on mutable terms. Returns None if the terms are identical +/// except for an incomparable superclass or concrete type symbol at the end. 
+Optional<int> +MutableTerm::compare(const MutableTerm &other, RewriteContext &ctx) const { return shortlexCompare(begin(), end(), other.begin(), other.end(), ctx); } diff --git a/lib/AST/RequirementMachine/Term.h b/lib/AST/RequirementMachine/Term.h index b3ca34dbdb0e6..77943c1d51397 100644 --- a/lib/AST/RequirementMachine/Term.h +++ b/lib/AST/RequirementMachine/Term.h @@ -79,7 +79,7 @@ class Term final { void dump(llvm::raw_ostream &out) const; - int compare(Term other, RewriteContext &ctx) const; + Optional<int> compare(Term other, RewriteContext &ctx) const; friend bool operator==(Term lhs, Term rhs) { return lhs.Ptr == rhs.Ptr; @@ -144,7 +144,7 @@ class MutableTerm final { Symbols.append(from, to); } - int compare(const MutableTerm &other, RewriteContext &ctx) const; + Optional<int> compare(const MutableTerm &other, RewriteContext &ctx) const; bool empty() const { return Symbols.empty(); } diff --git a/lib/AST/RequirementMachine/TypeDifference.cpp b/lib/AST/RequirementMachine/TypeDifference.cpp index bc0b39d6c9438..86c5870f3322a 100644 --- a/lib/AST/RequirementMachine/TypeDifference.cpp +++ b/lib/AST/RequirementMachine/TypeDifference.cpp @@ -63,7 +63,7 @@ void TypeDifference::verify(RewriteContext &ctx) const { for (const auto &pair : SameTypes) { auto first = LHS.getSubstitutions()[pair.first]; - VERIFY(first.compare(pair.second, ctx) > 0, "Order violation"); + VERIFY(*first.compare(pair.second, ctx) > 0, "Order violation"); VERIFY(lhsVisited.insert(pair.first).second, "Duplicate substitutions"); } @@ -131,8 +131,8 @@ namespace { auto lhsTerm = LHSSubstitutions[lhsIndex]; auto rhsTerm = RHSSubstitutions[rhsIndex]; - int compare = lhsTerm.compare(rhsTerm, Context); - if (compare < 0) { + Optional<int> compare = lhsTerm.compare(rhsTerm, Context); + if (*compare < 0) { SameTypesOnLHS.emplace_back(rhsIndex, lhsTerm); } else if (compare > 0) { SameTypesOnRHS.emplace_back(lhsIndex, rhsTerm); @@ -190,14 +190,14 @@ namespace { for (const auto &pair : SameTypesOnLHS) { auto first = 
RHSSubstitutions[pair.first]; - VERIFY(first.compare(pair.second, Context) > 0, "Order violation"); + VERIFY(*first.compare(pair.second, Context) > 0, "Order violation"); VERIFY(rhsVisited.insert(pair.first).second, "Duplicate substitution"); } for (const auto &pair : SameTypesOnRHS) { auto first = LHSSubstitutions[pair.first]; - VERIFY(first.compare(pair.second, Context) > 0, "Order violation"); + VERIFY(*first.compare(pair.second, Context) > 0, "Order violation"); VERIFY(lhsVisited.insert(pair.first).second, "Duplicate substitution"); } From 06d4770adb4dfe51c8f65c165631ce0c05387779 Mon Sep 17 00:00:00 2001 From: Slava Pestov Date: Fri, 4 Feb 2022 00:26:10 -0500 Subject: [PATCH 15/25] RequirementMachine: Teach homotopy reduction to better deal with incomparable rules --- .../RequirementMachine/HomotopyReduction.cpp | 136 +++++++++++++----- lib/AST/RequirementMachine/RewriteSystem.h | 5 +- 2 files changed, 101 insertions(+), 40 deletions(-) diff --git a/lib/AST/RequirementMachine/HomotopyReduction.cpp b/lib/AST/RequirementMachine/HomotopyReduction.cpp index e9fca5522c9c9..6fea8bd6fc85f 100644 --- a/lib/AST/RequirementMachine/HomotopyReduction.cpp +++ b/lib/AST/RequirementMachine/HomotopyReduction.cpp @@ -340,7 +340,7 @@ bool RewritePath::replaceRuleWithPath(unsigned ruleID, } /// Find a rule to delete by looking through all loops for rewrite rules appearing -/// once in empty context. Returns a redundant rule to delete if one was found, +/// once in empty context. Returns a pair consisting of a loop ID and a rule ID, /// otherwise returns None. /// /// Minimization performs three passes over the rewrite system. @@ -353,9 +353,8 @@ bool RewritePath::replaceRuleWithPath(unsigned ruleID, /// 3) Finally, redundant conformance rules are deleted, with /// \p redundantConformances equal to the set of conformance rules that are /// not minimal conformances. 
-Optional RewriteSystem:: -findRuleToDelete(llvm::function_ref isRedundantRuleFn, - RewritePath &replacementPath) { +Optional> RewriteSystem:: +findRuleToDelete(llvm::function_ref isRedundantRuleFn) { SmallVector, 2> redundancyCandidates; for (unsigned loopID : indices(Loops)) { auto &loop = Loops[loopID]; @@ -392,6 +391,9 @@ findRuleToDelete(llvm::function_ref isRedundantRuleFn, if (rule.isPermanent()) continue; + // Homotopy reduction runs multiple passes with different filters to + // prioritize the deletion of certain rules ahead of others. Apply + // the filter now. if (!isRedundantRuleFn(ruleID)) continue; @@ -400,47 +402,76 @@ findRuleToDelete(llvm::function_ref isRedundantRuleFn, continue; } - const auto &otherRule = getRule(found->second); - - if (rule.isConflicting() || - *rule.compare(otherRule, Context) > 0) { - found = pair; + if (Debug.contains(DebugFlags::HomotopyReduction)) { + llvm::dbgs() << "** Candidate " << rule << " from loop #" + << pair.first << "\n"; } - } - if (!found) - return None; + // 'rule' is the candidate rule; 'otherRule' is the best rule to eliminate + // we've found so far. + const auto &otherRule = getRule(found->second); - unsigned loopID = found->first; - unsigned ruleID = found->second; - assert(replacementPath.empty()); + unsigned ruleNesting = rule.getNesting(); + unsigned otherRuleNesting = otherRule.getNesting(); + + // If both rules are concrete type requirements, first compare nesting + // depth. This breaks the tie when we have two rules that each imply + // the other via an induced rule that comes from a protocol. + // + // For example, + // + // T == G + // U == Int + // + // Where T == G is implied elsewhere. 
+ if (ruleNesting > 0 && otherRuleNesting > 0) { + if (ruleNesting > otherRuleNesting) { + found = pair; + continue; + } else if (otherRuleNesting > ruleNesting) { + continue; + } + } - auto &loop = Loops[loopID]; - replacementPath = loop.Path.splitCycleAtRule(ruleID); + // Otherwise, perform a shortlex comparison on (LHS, RHS). + Optional comparison = rule.compare(otherRule, Context); + if (!comparison.hasValue()) { + // Two rules (T.[C] => T) and (T.[C'] => T) are incomparable if + // C and C' are superclass, concrete type or concrete conformance + // symbols. + // + // This should only arise in two limited situations: + // - The new rule was marked invalid due to a conflict. + // - The new rule was substitution-simplified. + // + // In both cases, the new rule becomes the new candidate for + // elimination. + if (!rule.isConflicting() && !rule.isSubstitutionSimplified()) { + llvm::errs() << "Incomparable rules in homotopy reduction:\n"; + llvm::errs() << "- Candidate rule: " << rule << "\n"; + llvm::errs() << "- Best rule so far: " << otherRule << "\n"; + abort(); + } - loop.markDeleted(); + found = pair; + continue; + } - auto &rule = getRule(ruleID); - rule.markRedundant(); + if (*comparison > 0) { + // Otherwise, if the new rule is less canonical than the best one so + // far, it becomes the new candidate for elimination. + found = pair; + continue; + } + } - return ruleID; + return found; } /// Delete a rewrite rule that is known to be redundant, replacing all /// occurrences of the rule in all loops with the replacement path. 
void RewriteSystem::deleteRule(unsigned ruleID, const RewritePath &replacementPath) { - if (Debug.contains(DebugFlags::HomotopyReduction)) { - const auto &rule = getRule(ruleID); - llvm::dbgs() << "* Deleting rule "; - rule.dump(llvm::dbgs()); - llvm::dbgs() << " (#" << ruleID << ")\n"; - llvm::dbgs() << "* Replacement path: "; - MutableTerm mutTerm(rule.getLHS()); - replacementPath.dump(llvm::dbgs(), mutTerm, *this); - llvm::dbgs() << "\n"; - } - // Replace all occurrences of the rule with the replacement path in // all remaining rewrite loops. for (auto &loop : Loops) { @@ -466,15 +497,34 @@ void RewriteSystem::deleteRule(unsigned ruleID, void RewriteSystem::performHomotopyReduction( llvm::function_ref isRedundantRuleFn) { while (true) { - RewritePath replacementPath; - auto optRuleID = findRuleToDelete(isRedundantRuleFn, - replacementPath); + auto optPair = findRuleToDelete(isRedundantRuleFn); // If no redundant rules remain which can be eliminated by this pass, stop. - if (!optRuleID) + if (!optPair) return; - deleteRule(*optRuleID, replacementPath); + unsigned loopID = optPair->first; + unsigned ruleID = optPair->second; + + auto &loop = Loops[loopID]; + auto replacementPath = loop.Path.splitCycleAtRule(ruleID); + + loop.markDeleted(); + + auto &rule = getRule(ruleID); + + if (Debug.contains(DebugFlags::HomotopyReduction)) { + llvm::dbgs() << "** Deleting rule " << rule << " from loop #" + << loopID << "\n"; + llvm::dbgs() << "* Replacement path: "; + MutableTerm mutTerm(getRule(ruleID).getLHS()); + replacementPath.dump(llvm::dbgs(), mutTerm, *this); + llvm::dbgs() << "\n"; + } + + rule.markRedundant(); + + deleteRule(ruleID, replacementPath); } } @@ -494,6 +544,10 @@ void RewriteSystem::minimizeRewriteSystem() { // - Eliminate all LHS-simplified non-conformance rules. // - Eliminate all RHS-simplified and substitution-simplified rules. // - Eliminate all rules with unresolved symbols. 
+ if (Debug.contains(DebugFlags::HomotopyReduction)) { + llvm::dbgs() << "\nFirst pass: simplified and unresolved rules\n\n"; + } + performHomotopyReduction([&](unsigned ruleID) -> bool { const auto &rule = getRule(ruleID); @@ -522,6 +576,10 @@ void RewriteSystem::minimizeRewriteSystem() { computeMinimalConformances(redundantConformances); // Second pass: Eliminate all non-minimal conformance rules. + if (Debug.contains(DebugFlags::HomotopyReduction)) { + llvm::dbgs() << "\nSecond pass: non-minimal conformance rules\n\n"; + } + performHomotopyReduction([&](unsigned ruleID) -> bool { const auto &rule = getRule(ruleID); @@ -533,6 +591,10 @@ void RewriteSystem::minimizeRewriteSystem() { }); // Third pass: Eliminate all other redundant non-conformance rules. + if (Debug.contains(DebugFlags::HomotopyReduction)) { + llvm::dbgs() << "\nThird pass: all other redundant rules\n\n"; + } + performHomotopyReduction([&](unsigned ruleID) -> bool { const auto &rule = getRule(ruleID); diff --git a/lib/AST/RequirementMachine/RewriteSystem.h b/lib/AST/RequirementMachine/RewriteSystem.h index f4047ec4542cd..5cfc74854acba 100644 --- a/lib/AST/RequirementMachine/RewriteSystem.h +++ b/lib/AST/RequirementMachine/RewriteSystem.h @@ -461,9 +461,8 @@ class RewriteSystem final { void propagateExplicitBits(); - Optional - findRuleToDelete(llvm::function_ref isRedundantRuleFn, - RewritePath &replacementPath); + Optional> + findRuleToDelete(llvm::function_ref isRedundantRuleFn); void deleteRule(unsigned ruleID, const RewritePath &replacementPath); From 00d226fd2fa589d17d031e5eb55fdf77473f14db Mon Sep 17 00:00:00 2001 From: Slava Pestov Date: Fri, 4 Feb 2022 03:31:21 -0500 Subject: [PATCH 16/25] RequirementMachine: Store the base term in the TypeDifference --- lib/AST/RequirementMachine/PropertyMap.cpp | 18 ++++++----- lib/AST/RequirementMachine/PropertyMap.h | 3 +- .../PropertyUnification.cpp | 3 +- lib/AST/RequirementMachine/RewriteSystem.h | 7 ++--- 
lib/AST/RequirementMachine/TypeDifference.cpp | 30 +++++++++---------- lib/AST/RequirementMachine/TypeDifference.h | 14 +++++---- 6 files changed, 41 insertions(+), 34 deletions(-) diff --git a/lib/AST/RequirementMachine/PropertyMap.cpp b/lib/AST/RequirementMachine/PropertyMap.cpp index df351a83a0231..ff36d6e042c49 100644 --- a/lib/AST/RequirementMachine/PropertyMap.cpp +++ b/lib/AST/RequirementMachine/PropertyMap.cpp @@ -407,7 +407,7 @@ void PropertyMap::buildPropertyMap() { /// Otherwise returns an index which can be passed to /// RewriteSystem::getTypeDifference(). Optional -PropertyMap::concretelySimplifySubstitutions(Symbol symbol, +PropertyMap::concretelySimplifySubstitutions(Term baseTerm, Symbol symbol, RewritePath *path) const { assert(symbol.hasSubstitutions()); @@ -503,13 +503,12 @@ PropertyMap::concretelySimplifySubstitutions(Symbol symbol, return None; } - auto difference = buildTypeDifference(symbol, sameTypes, concreteTypes, + auto difference = buildTypeDifference(baseTerm, symbol, + sameTypes, concreteTypes, Context); assert(difference.LHS != difference.RHS); - unsigned differenceID = System.recordTypeDifference(difference.LHS, - difference.RHS, - difference); + unsigned differenceID = System.recordTypeDifference(difference); // All simplified substitutions are now on the primary stack. Collect them to // produce the new term. 
@@ -529,13 +528,16 @@ void PropertyMap::concretelySimplifyLeftHandSideSubstitutions() const { rule.isSubstitutionSimplified()) continue; - auto symbol = rule.getLHS().back(); - if (!symbol.hasSubstitutions()) + auto optSymbol = rule.isPropertyRule(); + if (!optSymbol || !optSymbol->hasSubstitutions()) continue; + auto symbol = *optSymbol; + RewritePath path; - auto differenceID = concretelySimplifySubstitutions(symbol, &path); + auto differenceID = concretelySimplifySubstitutions( + rule.getRHS(), symbol, &path); if (!differenceID) continue; diff --git a/lib/AST/RequirementMachine/PropertyMap.h b/lib/AST/RequirementMachine/PropertyMap.h index ea346953cbc8f..a2d31171d7155 100644 --- a/lib/AST/RequirementMachine/PropertyMap.h +++ b/lib/AST/RequirementMachine/PropertyMap.h @@ -284,7 +284,8 @@ class PropertyMap { RequirementKind requirementKind, Symbol concreteConformanceSymbol) const; - Optional concretelySimplifySubstitutions(Symbol symbol, + Optional concretelySimplifySubstitutions(Term baseTerm, + Symbol symbol, RewritePath *path) const; void concretelySimplifyLeftHandSideSubstitutions() const; diff --git a/lib/AST/RequirementMachine/PropertyUnification.cpp b/lib/AST/RequirementMachine/PropertyUnification.cpp index f2224b5c3f042..5ed672f7a0487 100644 --- a/lib/AST/RequirementMachine/PropertyUnification.cpp +++ b/lib/AST/RequirementMachine/PropertyUnification.cpp @@ -436,7 +436,8 @@ void PropertyMap::addConcreteTypeProperty( Optional lhsDifferenceID; Optional rhsDifferenceID; - bool conflict = System.computeTypeDifference(*props->ConcreteType, property, + bool conflict = System.computeTypeDifference(key, + *props->ConcreteType, property, lhsDifferenceID, rhsDifferenceID); diff --git a/lib/AST/RequirementMachine/RewriteSystem.h b/lib/AST/RequirementMachine/RewriteSystem.h index 5cfc74854acba..6cfe2fd26e4a8 100644 --- a/lib/AST/RequirementMachine/RewriteSystem.h +++ b/lib/AST/RequirementMachine/RewriteSystem.h @@ -422,15 +422,14 @@ class RewriteSystem final { /// The 
map's values are indices into the vector. The map is used for /// uniquing, then the index is returned and lookups are performed into /// the vector. - llvm::DenseMap, unsigned> DifferenceMap; + llvm::DenseMap, unsigned> DifferenceMap; std::vector Differences; public: - unsigned recordTypeDifference(Symbol lhs, Symbol rhs, - const TypeDifference &difference); + unsigned recordTypeDifference(const TypeDifference &difference); bool - computeTypeDifference(Symbol lhs, Symbol rhs, + computeTypeDifference(Term term, Symbol lhs, Symbol rhs, Optional &lhsDifferenceID, Optional &rhsDifferenceID); diff --git a/lib/AST/RequirementMachine/TypeDifference.cpp b/lib/AST/RequirementMachine/TypeDifference.cpp index 86c5870f3322a..dc5167debe025 100644 --- a/lib/AST/RequirementMachine/TypeDifference.cpp +++ b/lib/AST/RequirementMachine/TypeDifference.cpp @@ -26,6 +26,7 @@ using namespace swift; using namespace rewriting; void TypeDifference::dump(llvm::raw_ostream &out) const { + llvm::errs() << "Base term: " << BaseTerm << "\n"; llvm::errs() << "LHS: " << LHS << "\n"; llvm::errs() << "RHS: " << RHS << "\n"; @@ -253,7 +254,7 @@ namespace { TypeDifference swift::rewriting::buildTypeDifference( - Symbol symbol, + Term baseTerm, Symbol symbol, const llvm::SmallVector, 1> &sameTypes, const llvm::SmallVector, 1> &concreteTypes, RewriteContext &ctx) { @@ -325,17 +326,16 @@ swift::rewriting::buildTypeDifference( llvm_unreachable("Bad symbol kind"); }(); - return {symbol, resultSymbol, sameTypes, concreteTypes}; + return {baseTerm, symbol, resultSymbol, sameTypes, concreteTypes}; } unsigned -RewriteSystem::recordTypeDifference(Symbol lhs, Symbol rhs, - const TypeDifference &difference) { - assert(lhs == difference.LHS); - assert(rhs == difference.RHS); - assert(lhs != rhs); +RewriteSystem::recordTypeDifference(const TypeDifference &difference) { + assert(difference.LHS != difference.RHS); - auto key = std::make_pair(lhs, rhs); + auto key = std::make_tuple(difference.BaseTerm, + 
difference.LHS, + difference.RHS); auto found = DifferenceMap.find(key); if (found != DifferenceMap.end()) return found->second; @@ -380,7 +380,7 @@ const TypeDifference &RewriteSystem::getTypeDifference(unsigned index) const { /// See the comment at the top of TypeDifference in TypeDifference.h for a /// description of the actual transformations. bool -RewriteSystem::computeTypeDifference(Symbol lhs, Symbol rhs, +RewriteSystem::computeTypeDifference(Term baseTerm, Symbol lhs, Symbol rhs, Optional &lhsDifferenceID, Optional &rhsDifferenceID) { assert(lhs.getKind() == rhs.getKind()); @@ -404,13 +404,13 @@ RewriteSystem::computeTypeDifference(Symbol lhs, Symbol rhs, matcher.verify(); - auto lhsMeetRhs = buildTypeDifference(lhs, + auto lhsMeetRhs = buildTypeDifference(baseTerm, lhs, matcher.SameTypesOnRHS, matcher.ConcreteTypesOnRHS, Context); lhsMeetRhs.verify(Context); - auto rhsMeetLhs = buildTypeDifference(rhs, + auto rhsMeetLhs = buildTypeDifference(baseTerm, rhs, matcher.SameTypesOnLHS, matcher.ConcreteTypesOnLHS, Context); @@ -437,7 +437,7 @@ RewriteSystem::computeTypeDifference(Symbol lhs, Symbol rhs, // The meet operation should be idempotent. 
{ // (LHS ∧ (LHS ∧ RHS)) == (LHS ∧ RHS) - auto lhsMeetLhsMeetRhs = buildTypeDifference(lhs, + auto lhsMeetLhsMeetRhs = buildTypeDifference(baseTerm, lhs, lhsMeetRhs.SameTypes, lhsMeetRhs.ConcreteTypes, Context); @@ -460,7 +460,7 @@ RewriteSystem::computeTypeDifference(Symbol lhs, Symbol rhs, { // (RHS ∧ (RHS ∧ LHS)) == (RHS ∧ LHS) - auto rhsMeetRhsMeetRhs = buildTypeDifference(rhs, + auto rhsMeetRhsMeetRhs = buildTypeDifference(baseTerm, rhs, rhsMeetLhs.SameTypes, rhsMeetLhs.ConcreteTypes, Context); @@ -484,10 +484,10 @@ RewriteSystem::computeTypeDifference(Symbol lhs, Symbol rhs, #endif if (lhs != lhsMeetRhs.RHS) - lhsDifferenceID = recordTypeDifference(lhs, lhsMeetRhs.RHS, lhsMeetRhs); + lhsDifferenceID = recordTypeDifference(lhsMeetRhs); if (rhs != rhsMeetLhs.RHS) - rhsDifferenceID = recordTypeDifference(rhs, rhsMeetLhs.RHS, rhsMeetLhs); + rhsDifferenceID = recordTypeDifference(rhsMeetLhs); return isConflict; } \ No newline at end of file diff --git a/lib/AST/RequirementMachine/TypeDifference.h b/lib/AST/RequirementMachine/TypeDifference.h index 5d1f94c1b764c..a890f88769f47 100644 --- a/lib/AST/RequirementMachine/TypeDifference.h +++ b/lib/AST/RequirementMachine/TypeDifference.h @@ -32,12 +32,15 @@ namespace rewriting { class RewriteContext; -/// Describes transformations that turn LHS into RHS. There are two kinds of -/// transformations: +/// Describes transformations that turn LHS into RHS, given that there are a +/// pair of rules (BaseTerm.[LHS] => BaseTerm) and (BaseTerm.[RHS] => BaseTerm). +/// +/// There are two kinds of transformations: /// /// - Replacing a type term T1 with another type term T2, where T2 < T1. /// - Replacing a type term T1 with a concrete type C2. struct TypeDifference { + Term BaseTerm; Symbol LHS; Symbol RHS; @@ -49,10 +52,11 @@ struct TypeDifference { /// C2 is a concrete type symbol. 
SmallVector, 1> ConcreteTypes; - TypeDifference(Symbol lhs, Symbol rhs, + TypeDifference(Term baseTerm, Symbol lhs, Symbol rhs, SmallVector, 1> sameTypes, SmallVector, 1> concreteTypes) - : LHS(lhs), RHS(rhs), SameTypes(sameTypes), ConcreteTypes(concreteTypes) {} + : BaseTerm(baseTerm), LHS(lhs), RHS(rhs), + SameTypes(sameTypes), ConcreteTypes(concreteTypes) {} void dump(llvm::raw_ostream &out) const; void verify(RewriteContext &ctx) const; @@ -60,7 +64,7 @@ struct TypeDifference { TypeDifference buildTypeDifference( - Symbol symbol, + Term baseTerm, Symbol symbol, const llvm::SmallVector, 1> &sameTypes, const llvm::SmallVector, 1> &concreteTypes, RewriteContext &ctx); From 2c355de71b3c21e6614dd4b4c329f878681d3ed8 Mon Sep 17 00:00:00 2001 From: Slava Pestov Date: Fri, 4 Feb 2022 03:37:42 -0500 Subject: [PATCH 17/25] RequirementMachine: Introduce RewriteStep::{Left,Right}ConcreteProjection --- .../RequirementMachine/HomotopyReduction.cpp | 34 ++-- .../MinimalConformances.cpp | 2 + lib/AST/RequirementMachine/RewriteLoop.cpp | 163 ++++++++++++++++++ lib/AST/RequirementMachine/RewriteLoop.h | 138 ++++++++++++++- 4 files changed, 318 insertions(+), 19 deletions(-) diff --git a/lib/AST/RequirementMachine/HomotopyReduction.cpp b/lib/AST/RequirementMachine/HomotopyReduction.cpp index 6fea8bd6fc85f..c9b985c0bfcc6 100644 --- a/lib/AST/RequirementMachine/HomotopyReduction.cpp +++ b/lib/AST/RequirementMachine/HomotopyReduction.cpp @@ -99,6 +99,8 @@ RewriteLoop::findRulesAppearingOnceInEmptyContext( case RewriteStep::Decompose: case RewriteStep::Relation: case RewriteStep::DecomposeConcrete: + case RewriteStep::LeftConcreteProjection: + case RewriteStep::RightConcreteProjection: break; } @@ -219,6 +221,8 @@ RewritePath RewritePath::splitCycleAtRule(unsigned ruleID) const { case RewriteStep::Decompose: case RewriteStep::Relation: case RewriteStep::DecomposeConcrete: + case RewriteStep::LeftConcreteProjection: + case RewriteStep::RightConcreteProjection: break; } @@ -280,23 
+284,19 @@ bool RewritePath::replaceRuleWithPath(unsigned ruleID, // Replace this step with the provided path. If this rewrite step has // context, the path's own steps must be re-contextualized. - // Keep track of Decompose/DecomposeConcrete pairs. Any rewrite steps - // in between do not need to be re-contextualized, since they operate - // on new terms that were pushed on the stack by the Decompose or - // DecomposeConcrete operation. - unsigned decomposeCount = 0; + // Keep track of rewrite step pairs which push and pop the stack. Any + // rewrite steps enclosed with a push/pop are not re-contextualized. + unsigned pushCount = 0; auto recontextualizeStep = [&](RewriteStep newStep) { bool inverse = newStep.Inverse ^ step.Inverse; - if ((newStep.Kind == RewriteStep::Decompose || - newStep.Kind == RewriteStep::DecomposeConcrete) && - inverse) { - assert(decomposeCount > 0); - --decomposeCount; + if (newStep.pushesTermsOnStack() && inverse) { + assert(pushCount > 0); + --pushCount; } - if (decomposeCount == 0) { + if (pushCount == 0) { newStep.StartOffset += step.StartOffset; newStep.EndOffset += step.EndOffset; } @@ -304,10 +304,8 @@ bool RewritePath::replaceRuleWithPath(unsigned ruleID, newStep.Inverse = inverse; newSteps.push_back(newStep); - if ((newStep.Kind == RewriteStep::Decompose || - newStep.Kind == RewriteStep::DecomposeConcrete) && - !inverse) { - ++decomposeCount; + if (newStep.pushesTermsOnStack() && !inverse) { + ++pushCount; } }; @@ -320,8 +318,8 @@ bool RewritePath::replaceRuleWithPath(unsigned ruleID, recontextualizeStep(newStep); } - // Decompose and DecomposeConcrete steps should come in balanced pairs. - assert(decomposeCount == 0); + // Rewrite steps which push and pop the stack must come in balanced pairs. 
+ assert(pushCount == 0); break; } @@ -330,6 +328,8 @@ bool RewritePath::replaceRuleWithPath(unsigned ruleID, case RewriteStep::Decompose: case RewriteStep::Relation: case RewriteStep::DecomposeConcrete: + case RewriteStep::LeftConcreteProjection: + case RewriteStep::RightConcreteProjection: newSteps.push_back(step); break; } diff --git a/lib/AST/RequirementMachine/MinimalConformances.cpp b/lib/AST/RequirementMachine/MinimalConformances.cpp index 868084726f425..bc0f75d34616d 100644 --- a/lib/AST/RequirementMachine/MinimalConformances.cpp +++ b/lib/AST/RequirementMachine/MinimalConformances.cpp @@ -139,6 +139,8 @@ void RewriteLoop::findProtocolConformanceRules( case RewriteStep::Decompose: case RewriteStep::Relation: case RewriteStep::DecomposeConcrete: + case RewriteStep::LeftConcreteProjection: + case RewriteStep::RightConcreteProjection: break; } } diff --git a/lib/AST/RequirementMachine/RewriteLoop.cpp b/lib/AST/RequirementMachine/RewriteLoop.cpp index 09a98dfc4f3d0..9c5c6c6cc0614 100644 --- a/lib/AST/RequirementMachine/RewriteLoop.cpp +++ b/lib/AST/RequirementMachine/RewriteLoop.cpp @@ -136,6 +136,28 @@ void RewriteStep::dump(llvm::raw_ostream &out, out << difference.LHS << " : " << difference.RHS << ")"; break; } + case LeftConcreteProjection: { + evaluator.applyLeftConcreteProjection(*this, system); + + out << "LeftConcrete" << (Inverse ? "In" : "Pro") << "jection("; + + const auto &difference = system.getTypeDifference( + getTypeDifferenceID()); + + out << difference.LHS << " : " << difference.RHS << ")"; + break; + } + case RightConcreteProjection: { + evaluator.applyRightConcreteProjection(*this, system); + + out << "RightConcrete" << (Inverse ? 
"In" : "Pro") << "jection("; + + const auto &difference = system.getTypeDifference( + getTypeDifferenceID()); + + out << difference.LHS << " : " << difference.RHS << ")"; + break; + } } } @@ -542,6 +564,139 @@ void RewritePathEvaluator::applyDecomposeConcrete(const RewriteStep &step, } } +void +RewritePathEvaluator::applyLeftConcreteProjection(const RewriteStep &step, + const RewriteSystem &system) { + assert(step.Kind == RewriteStep::LeftConcreteProjection); + + const auto &difference = system.getTypeDifference(step.getTypeDifferenceID()); + unsigned index = step.getSubstitutionIndex(); + + MutableTerm leftProjection(difference.LHS.getSubstitutions()[index]); + + MutableTerm leftBaseTerm(difference.BaseTerm); + leftBaseTerm.add(difference.LHS); + + auto bug = [&](StringRef msg) { + llvm::errs() << msg << "\n"; + llvm::errs() << "- StartOffset: " << step.StartOffset << "\n"; + llvm::errs() << "- EndOffset: " << step.EndOffset << "\n"; + llvm::errs() << "- SubstitutionIndex: " << index << "\n"; + llvm::errs() << "- LeftProjection: " << leftProjection << "\n"; + llvm::errs() << "- LeftBaseTerm: " << leftBaseTerm << "\n"; + llvm::errs() << "- DifferenceID: " << step.getTypeDifferenceID() << "\n"; + llvm::errs() << "\nType difference:\n"; + difference.dump(llvm::errs()); + llvm::errs() << ":\n"; + difference.dump(llvm::errs()); + llvm::errs() << "\nEvaluator state:\n"; + dump(llvm::errs()); + abort(); + }; + + if (!step.Inverse) { + const auto &term = getCurrentTerm(); + + MutableTerm subTerm(term.begin() + step.StartOffset, + term.end() - step.EndOffset); + if (subTerm != MutableTerm(leftProjection)) + bug("Incorrect left projection term"); + + Primary.push_back(leftBaseTerm); + } else { + if (Primary.size() < 2) + bug("Too few elements on the primary stack"); + + if (Primary.back() != leftBaseTerm) + bug("Incorrect left base term"); + + Primary.pop_back(); + + const auto &term = getCurrentTerm(); + + MutableTerm subTerm(term.begin() + step.StartOffset, + term.end() - 
step.EndOffset); + if (subTerm != leftProjection) + bug("Incorrect left projection term"); + } +} + +void +RewritePathEvaluator::applyRightConcreteProjection(const RewriteStep &step, + const RewriteSystem &system) { + assert(step.Kind == RewriteStep::RightConcreteProjection); + + const auto &difference = system.getTypeDifference(step.getTypeDifferenceID()); + unsigned index = step.getSubstitutionIndex(); + + MutableTerm leftProjection(difference.LHS.getSubstitutions()[index]); + auto rightProjection = difference.getReplacementSubstitution(index); + + MutableTerm leftBaseTerm(difference.BaseTerm); + leftBaseTerm.add(difference.LHS); + + MutableTerm rightBaseTerm(difference.BaseTerm); + rightBaseTerm.add(difference.RHS); + + auto bug = [&](StringRef msg) { + llvm::errs() << msg << "\n"; + llvm::errs() << "- StartOffset: " << step.StartOffset << "\n"; + llvm::errs() << "- EndOffset: " << step.EndOffset << "\n"; + llvm::errs() << "- SubstitutionIndex: " << index << "\n"; + llvm::errs() << "- LeftProjection: " << leftProjection << "\n"; + llvm::errs() << "- RightProjection: " << rightProjection << "\n"; + llvm::errs() << "- LeftBaseTerm: " << leftBaseTerm << "\n"; + llvm::errs() << "- RightBaseTerm: " << rightBaseTerm << "\n"; + llvm::errs() << "- DifferenceID: " << step.getTypeDifferenceID() << "\n"; + llvm::errs() << "\nType difference:\n"; + difference.dump(llvm::errs()); + llvm::errs() << ":\n"; + difference.dump(llvm::errs()); + llvm::errs() << "\nEvaluator state:\n"; + dump(llvm::errs()); + abort(); + }; + + if (!step.Inverse) { + auto &term = getCurrentTerm(); + + MutableTerm subTerm(term.begin() + step.StartOffset, + term.end() - step.EndOffset); + + if (subTerm != rightProjection) + bug("Incorrect right projection term"); + + MutableTerm newTerm(term.begin(), term.begin() + step.StartOffset); + newTerm.append(leftProjection); + newTerm.append(term.end() - step.EndOffset, term.end()); + + term = newTerm; + + Primary.push_back(rightBaseTerm); + } else { + if 
(Primary.size() < 2) + bug("Too few elements on the primary stack"); + + if (Primary.back() != rightBaseTerm) + bug("Incorrect right base term"); + + Primary.pop_back(); + + auto &term = getCurrentTerm(); + + MutableTerm subTerm(term.begin() + step.StartOffset, + term.end() - step.EndOffset); + if (subTerm != leftProjection) + bug("Incorrect left projection term"); + + MutableTerm newTerm(term.begin(), term.begin() + step.StartOffset); + newTerm.append(rightProjection); + newTerm.append(term.end() - step.EndOffset, term.end()); + + term = newTerm; + } +} + void RewritePathEvaluator::apply(const RewriteStep &step, const RewriteSystem &system) { switch (step.Kind) { @@ -568,5 +723,13 @@ void RewritePathEvaluator::apply(const RewriteStep &step, case RewriteStep::DecomposeConcrete: applyDecomposeConcrete(step, system); break; + + case RewriteStep::LeftConcreteProjection: + applyLeftConcreteProjection(step, system); + break; + + case RewriteStep::RightConcreteProjection: + applyRightConcreteProjection(step, system); + break; } } diff --git a/lib/AST/RequirementMachine/RewriteLoop.h b/lib/AST/RequirementMachine/RewriteLoop.h index 0efdb9e70ae4e..58208dd9ec17c 100644 --- a/lib/AST/RequirementMachine/RewriteLoop.h +++ b/lib/AST/RequirementMachine/RewriteLoop.h @@ -93,6 +93,9 @@ struct RewriteStep { /// T.[concrete: C<...> with ] /// /// The Arg field encodes the number of substitutions. + /// + /// Used by RewriteSystem::simplifyLeftHandSideSubstitutions() and + /// PropertyMap::concretelySimplifyLeftHandSideSubstitutions(). Decompose, /// @@ -151,7 +154,66 @@ struct RewriteStep { /// term ending with the TypeDifference RHS at the top of the primary stack: /// /// T.[concrete: C'<...> with ] - DecomposeConcrete + /// + /// Used by PropertyMap::concretelySimplifyLeftHandSideSubstitutions() and + /// PropertyMap::processTypeDifference(). 
+ DecomposeConcrete, + + /// For decomposing the left hand side of an induced rule in concrete type + /// unification, using a TypeDifference that has been computed previously. + /// + /// The Arg field is a TypeDifference ID together with a substitution index + /// of the TypeDifference LHS which identifies the induced rule. + /// + /// Say the TypeDifference LHS is [concrete: C<...> with ], and + /// say the TypeDifference RHS is [concrete: C'<...> with ]. + /// + /// Note that the LHS and RHS may have a different number of substitutions. + /// + /// Furthermore, let T be the base term of the TypeDifference, meaning that + /// the TypeDifference was derived from a pair of concrete type rules + /// (T.[LHS] => T) and (T.[RHS] => T). + /// + /// If not inverted: the top of the primary stack must be the term Xn, + /// where n is the substitution index of the type difference. + /// + /// Then, the term T.[LHS] is pushed on the primary stack. + /// + /// If inverted: the top of the primary stack must be T.[LHS], which is + /// popped. The next term must be the term Xn. + /// + /// Used by buildRewritePathForInducedRule() in PropertyMap.cpp. + LeftConcreteProjection, + + /// For introducing the right hand side of an induced rule in concrete type + /// unification, using a TypeDifference that has been computed previously. + /// + /// If not inverted: the top of the primary stack must be the term f(Xn), + /// where n is the substitution index of the type difference. There are + /// three cases: + /// + /// - The substitution index appears in the SameTypes list of the + /// TypeDifference. In this case, f(Xn) is the right hand side of the + /// entry in the SameTypes list. + /// + /// - The substitution index appears in the ConcreteTypes list of the + /// TypeDifference. In this case, f(Xn) is Xn.[concrete: D] where D + /// is the right hand side of the entry in the ConcreteTypes list. 
+ /// + /// - The substitution index does not appear in either list, in which case + /// it is unchanged and f(Xn) == Xn. + /// + /// The term f(Xn) is replaced with the original substitution Xn at the + /// top of the primary stack. + /// + /// Then, the term T.[RHS] is pushed on the primary stack. + /// + /// If inverted: the top of the primary stack must be T.[RHS], which is + /// popped. The next term must be the term f(Xn), which is replaced with + /// Xn. + /// + /// Used by buildRewritePathForInducedRule() in PropertyMap.cpp. + RightConcreteProjection }; /// The rewrite step kind. @@ -178,7 +240,16 @@ struct RewriteStep { /// /// If Kind is Relation, the relation index returned from /// RewriteSystem::recordRelation(). - unsigned Arg : 16; + /// + /// If Kind is DecomposeConcrete, the type difference ID returned from + /// RewriteSystem::recordTypeDifference(). + /// + /// If Kind is LeftConcreteProjection or RightConcreteProjection, the + /// type difference ID returned from RewriteSystem::recordTypeDifference() + /// in the most significant 16 bits, together with the substitution index + /// in the least significant 16 bits. See getConcreteProjectionArg(), + /// getTypeDifferenceID() and getSubstitutionIndex().
+ unsigned Arg; RewriteStep(StepKind kind, unsigned startOffset, unsigned endOffset, unsigned arg, bool inverse) { @@ -225,10 +296,46 @@ struct RewriteStep { /*arg=*/differenceID, inverse); } + static RewriteStep forLeftConcreteProjection(unsigned differenceID, + unsigned substitutionIndex, + bool inverse) { + unsigned arg = getConcreteProjectionArg(differenceID, substitutionIndex); + return RewriteStep(LeftConcreteProjection, + /*startOffset=*/0, /*endOffset=*/0, + arg, inverse); + } + + static RewriteStep forRightConcreteProjection(unsigned differenceID, + unsigned substitutionIndex, + bool inverse) { + unsigned arg = getConcreteProjectionArg(differenceID, substitutionIndex); + return RewriteStep(RightConcreteProjection, + /*startOffset=*/0, /*endOffset=*/0, + arg, inverse); + } + bool isInContext() const { return StartOffset > 0 || EndOffset > 0; } + bool pushesTermsOnStack() const { + switch (Kind) { + case RewriteStep::Rule: + case RewriteStep::PrefixSubstitutions: + case RewriteStep::Relation: + case RewriteStep::Shift: + return false; + + case RewriteStep::Decompose: + case RewriteStep::DecomposeConcrete: + case RewriteStep::LeftConcreteProjection: + case RewriteStep::RightConcreteProjection: + return true; + } + + llvm_unreachable("Bad step kind"); + } + void invert() { Inverse = !Inverse; } @@ -238,9 +345,30 @@ struct RewriteStep { return Arg; } + unsigned getTypeDifferenceID() const { + assert(Kind == RewriteStep::LeftConcreteProjection || + Kind == RewriteStep::RightConcreteProjection); + return (Arg >> 16) & 0xffff; + } + + unsigned getSubstitutionIndex() const { + assert(Kind == RewriteStep::LeftConcreteProjection || + Kind == RewriteStep::RightConcreteProjection); + return Arg & 0xffff; + } + void dump(llvm::raw_ostream &out, RewritePathEvaluator &evaluator, const RewriteSystem &system) const; + +private: + static unsigned getConcreteProjectionArg(unsigned differenceID, + unsigned substitutionIndex) { + assert(differenceID <= 0xffff); + 
assert(substitutionIndex <= 0xffff); + + return (differenceID << 16) | substitutionIndex; + } }; /// Records a sequence of zero or more rewrite rules applied to a term. @@ -420,6 +548,12 @@ struct RewritePathEvaluator { void applyDecomposeConcrete(const RewriteStep &step, const RewriteSystem &system); + void applyLeftConcreteProjection(const RewriteStep &step, + const RewriteSystem &system); + + void applyRightConcreteProjection(const RewriteStep &step, + const RewriteSystem &system); + void dump(llvm::raw_ostream &out) const; }; From fcd467ad4cb6d4b45f5b9fa1eadc0b8948f59fc3 Mon Sep 17 00:00:00 2001 From: Slava Pestov Date: Fri, 4 Feb 2022 03:53:16 -0500 Subject: [PATCH 18/25] RequirementMachine: Factor out TypeDifference::getReplacementSubstitution() --- lib/AST/RequirementMachine/RewriteLoop.cpp | 26 ++----------------- lib/AST/RequirementMachine/TypeDifference.cpp | 22 ++++++++++++++++ lib/AST/RequirementMachine/TypeDifference.h | 2 ++ 3 files changed, 26 insertions(+), 24 deletions(-) diff --git a/lib/AST/RequirementMachine/RewriteLoop.cpp b/lib/AST/RequirementMachine/RewriteLoop.cpp index 9c5c6c6cc0614..585420d09c1e4 100644 --- a/lib/AST/RequirementMachine/RewriteLoop.cpp +++ b/lib/AST/RequirementMachine/RewriteLoop.cpp @@ -496,28 +496,6 @@ void RewritePathEvaluator::applyDecomposeConcrete(const RewriteStep &step, auto substitutions = difference.LHS.getSubstitutions(); - auto getReplacementSubstitution = [&](unsigned n) -> MutableTerm { - for (const auto &pair : difference.SameTypes) { - if (pair.first == n) { - // Given a transformation Xn -> Xn', return the term Xn'. - return MutableTerm(pair.second); - } - } - - for (const auto &pair : difference.ConcreteTypes) { - if (pair.first == n) { - // Given a transformation Xn -> [concrete: D], return the - // return Xn.[concrete: D]. - MutableTerm result(substitutions[n]); - result.add(pair.second); - return result; - } - } - - // Otherwise return the original substitution Xn. 
- return MutableTerm(substitutions[n]); - }; - if (!step.Inverse) { auto &term = getCurrentTerm(); @@ -531,7 +509,7 @@ void RewritePathEvaluator::applyDecomposeConcrete(const RewriteStep &step, term = newTerm; for (unsigned n : indices(substitutions)) - Primary.push_back(getReplacementSubstitution(n)); + Primary.push_back(difference.getReplacementSubstitution(n)); } else { unsigned numSubstitutions = substitutions.size(); @@ -541,7 +519,7 @@ void RewritePathEvaluator::applyDecomposeConcrete(const RewriteStep &step, for (unsigned n : indices(substitutions)) { const auto &otherSubstitution = *(Primary.end() - numSubstitutions + n); - auto expectedSubstitution = getReplacementSubstitution(n); + auto expectedSubstitution = difference.getReplacementSubstitution(n); if (otherSubstitution != expectedSubstitution) { llvm::errs() << "Got: " << otherSubstitution << "\n"; llvm::errs() << "Expected: " << expectedSubstitution << "\n"; diff --git a/lib/AST/RequirementMachine/TypeDifference.cpp b/lib/AST/RequirementMachine/TypeDifference.cpp index dc5167debe025..d4e798a6d25d7 100644 --- a/lib/AST/RequirementMachine/TypeDifference.cpp +++ b/lib/AST/RequirementMachine/TypeDifference.cpp @@ -25,6 +25,28 @@ using namespace swift; using namespace rewriting; +MutableTerm TypeDifference::getReplacementSubstitution(unsigned index) const { + for (const auto &pair : SameTypes) { + if (pair.first == index) { + // Given a transformation Xn -> Xn', return the term Xn'. + return MutableTerm(pair.second); + } + } + + for (const auto &pair : ConcreteTypes) { + if (pair.first == index) { + // Given a transformation Xn -> [concrete: D], return the + // return Xn.[concrete: D]. + MutableTerm result(LHS.getSubstitutions()[index]); + result.add(pair.second); + return result; + } + } + + // Otherwise return the original substitution Xn. 
+ return MutableTerm(LHS.getSubstitutions()[index]); +} + void TypeDifference::dump(llvm::raw_ostream &out) const { llvm::errs() << "Base term: " << BaseTerm << "\n"; llvm::errs() << "LHS: " << LHS << "\n"; diff --git a/lib/AST/RequirementMachine/TypeDifference.h b/lib/AST/RequirementMachine/TypeDifference.h index a890f88769f47..a3e2189870c7d 100644 --- a/lib/AST/RequirementMachine/TypeDifference.h +++ b/lib/AST/RequirementMachine/TypeDifference.h @@ -58,6 +58,8 @@ struct TypeDifference { : BaseTerm(baseTerm), LHS(lhs), RHS(rhs), SameTypes(sameTypes), ConcreteTypes(concreteTypes) {} + MutableTerm getReplacementSubstitution(unsigned index) const; + void dump(llvm::raw_ostream &out) const; void verify(RewriteContext &ctx) const; }; From 60db9174e695d58e10562150fe28359d1a2f5303 Mon Sep 17 00:00:00 2001 From: Slava Pestov Date: Fri, 4 Feb 2022 04:05:34 -0500 Subject: [PATCH 19/25] RequirementMachine: Factor out PropertyMap::processTypeDifference() --- lib/AST/RequirementMachine/PropertyMap.h | 5 + .../PropertyUnification.cpp | 93 +++++++++++-------- 2 files changed, 57 insertions(+), 41 deletions(-) diff --git a/lib/AST/RequirementMachine/PropertyMap.h b/lib/AST/RequirementMachine/PropertyMap.h index a2d31171d7155..6e37281129942 100644 --- a/lib/AST/RequirementMachine/PropertyMap.h +++ b/lib/AST/RequirementMachine/PropertyMap.h @@ -243,6 +243,11 @@ class PropertyMap { void addProperty(Term key, Symbol property, unsigned ruleID); + void processTypeDifference(const TypeDifference &difference, + unsigned differenceID, + unsigned lhsRuleID, + unsigned rhsRuleID); + void addConformanceProperty(Term key, Symbol property, unsigned ruleID); void addLayoutProperty(Term key, Symbol property, unsigned ruleID); void addSuperclassProperty(Term key, Symbol property, unsigned ruleID); diff --git a/lib/AST/RequirementMachine/PropertyUnification.cpp b/lib/AST/RequirementMachine/PropertyUnification.cpp index 5ed672f7a0487..d8744d034887d 100644 --- 
a/lib/AST/RequirementMachine/PropertyUnification.cpp +++ b/lib/AST/RequirementMachine/PropertyUnification.cpp @@ -391,6 +391,48 @@ void PropertyMap::addSuperclassProperty( } } +/// Record induced rules from the given type difference. +void PropertyMap::processTypeDifference(const TypeDifference &difference, + unsigned differenceID, + unsigned lhsRuleID, + unsigned rhsRuleID) { + bool debug = Debug.contains(DebugFlags::ConcreteUnification); + + if (debug) { + difference.dump(llvm::dbgs()); + } + + for (const auto &pair : difference.SameTypes) { + // Both sides are type parameters; add a same-type requirement. + MutableTerm lhsTerm(difference.LHS.getSubstitutions()[pair.first]); + MutableTerm rhsTerm(pair.second); + + if (debug) { + llvm::dbgs() << "%% Induced rule " << lhsTerm + << " == " << rhsTerm << "\n"; + } + + // FIXME: Need a rewrite path here. + System.addRule(lhsTerm, rhsTerm); + } + + for (const auto &pair : difference.ConcreteTypes) { + // A type parameter is equated with a concrete type; add a concrete + // type requirement. + MutableTerm rhsTerm(difference.LHS.getSubstitutions()[pair.first]); + MutableTerm lhsTerm(rhsTerm); + lhsTerm.add(pair.second); + + if (debug) { + llvm::dbgs() << "%% Induced rule " << lhsTerm + << " == " << rhsTerm << "\n"; + } + + // FIXME: Need a rewrite path here. + System.addRule(lhsTerm, rhsTerm); + } +} + /// When a type parameter has two concrete types, we have to unify the /// type constructor arguments. /// @@ -450,43 +492,6 @@ void PropertyMap::addConcreteTypeProperty( return; } - // Record induced rules from the given type difference. - auto processTypeDifference = [&](const TypeDifference &difference) { - if (debug) { - difference.dump(llvm::dbgs()); - } - - for (const auto &pair : difference.SameTypes) { - // Both sides are type parameters; add a same-type requirement. 
- MutableTerm lhsTerm(difference.LHS.getSubstitutions()[pair.first]); - MutableTerm rhsTerm(pair.second); - - if (debug) { - llvm::dbgs() << "%% Induced rule " << lhsTerm - << " == " << rhsTerm << "\n"; - } - - // FIXME: Need a rewrite path here. - System.addRule(lhsTerm, rhsTerm); - } - - for (const auto &pair : difference.ConcreteTypes) { - // A type parameter is equated with a concrete type; add a concrete - // type requirement. - MutableTerm rhsTerm(difference.LHS.getSubstitutions()[pair.first]); - MutableTerm lhsTerm(rhsTerm); - lhsTerm.add(pair.second); - - if (debug) { - llvm::dbgs() << "%% Induced rule " << lhsTerm - << " == " << rhsTerm << "\n"; - } - - // FIXME: Need a rewrite path here. - System.addRule(lhsTerm, rhsTerm); - } - }; - // Handle the case where (LHS ∧ RHS) is distinct from both LHS and RHS: // - First, record a new rule. // - Next, process the LHS -> (LHS ∧ RHS) difference. @@ -510,6 +515,8 @@ void PropertyMap::addConcreteTypeProperty( << " == " << rhsTerm << "\n"; } + // This rule does not need a rewrite path because it will be related + // to the existing rule in concretelySimplifyLeftHandSideSubstitutions(). System.addRule(lhsTerm, rhsTerm); } @@ -529,11 +536,13 @@ void PropertyMap::addConcreteTypeProperty( // Process LHS -> (LHS ∧ RHS). if (checkRulePairOnce(*props->ConcreteTypeRule, newRuleID)) - processTypeDifference(lhsDifference); + processTypeDifference(lhsDifference, *lhsDifferenceID, + *props->ConcreteTypeRule, newRuleID); // Process RHS -> (LHS ∧ RHS). if (checkRulePairOnce(ruleID, newRuleID)) - processTypeDifference(rhsDifference); + processTypeDifference(rhsDifference, *rhsDifferenceID, + ruleID, newRuleID); // The new property is more specific, so update ConcreteType and // ConcreteTypeRule. 
@@ -552,7 +561,8 @@ void PropertyMap::addConcreteTypeProperty( assert(property == lhsDifference.RHS); if (checkRulePairOnce(*props->ConcreteTypeRule, ruleID)) - processTypeDifference(lhsDifference); + processTypeDifference(lhsDifference, *lhsDifferenceID, + *props->ConcreteTypeRule, ruleID); // The new property is more specific, so update ConcreteType and // ConcreteTypeRule. @@ -571,7 +581,8 @@ void PropertyMap::addConcreteTypeProperty( assert(*props->ConcreteType == rhsDifference.RHS); if (checkRulePairOnce(*props->ConcreteTypeRule, ruleID)) - processTypeDifference(rhsDifference); + processTypeDifference(rhsDifference, *rhsDifferenceID, + ruleID, *props->ConcreteTypeRule); // The new property is less specific, so ConcreteType and ConcreteTypeRule // remain unchanged. From 9ee702026f8862c45db323f54de558f65e0a220c Mon Sep 17 00:00:00 2001 From: Slava Pestov Date: Fri, 4 Feb 2022 04:10:05 -0500 Subject: [PATCH 20/25] RequirementMachine: Factor out TypeDifference::getOriginalSubstitution() --- lib/AST/RequirementMachine/PropertyUnification.cpp | 4 ++-- lib/AST/RequirementMachine/RewriteLoop.cpp | 4 ++-- lib/AST/RequirementMachine/TypeDifference.cpp | 12 ++++++++---- lib/AST/RequirementMachine/TypeDifference.h | 1 + 4 files changed, 13 insertions(+), 8 deletions(-) diff --git a/lib/AST/RequirementMachine/PropertyUnification.cpp b/lib/AST/RequirementMachine/PropertyUnification.cpp index d8744d034887d..950f893350677 100644 --- a/lib/AST/RequirementMachine/PropertyUnification.cpp +++ b/lib/AST/RequirementMachine/PropertyUnification.cpp @@ -404,7 +404,7 @@ void PropertyMap::processTypeDifference(const TypeDifference &difference, for (const auto &pair : difference.SameTypes) { // Both sides are type parameters; add a same-type requirement. 
- MutableTerm lhsTerm(difference.LHS.getSubstitutions()[pair.first]); + auto lhsTerm = difference.getOriginalSubstitution(pair.first); MutableTerm rhsTerm(pair.second); if (debug) { @@ -419,7 +419,7 @@ void PropertyMap::processTypeDifference(const TypeDifference &difference, for (const auto &pair : difference.ConcreteTypes) { // A type parameter is equated with a concrete type; add a concrete // type requirement. - MutableTerm rhsTerm(difference.LHS.getSubstitutions()[pair.first]); + auto rhsTerm = difference.getOriginalSubstitution(pair.first); MutableTerm lhsTerm(rhsTerm); lhsTerm.add(pair.second); diff --git a/lib/AST/RequirementMachine/RewriteLoop.cpp b/lib/AST/RequirementMachine/RewriteLoop.cpp index 585420d09c1e4..1e3694f414fc4 100644 --- a/lib/AST/RequirementMachine/RewriteLoop.cpp +++ b/lib/AST/RequirementMachine/RewriteLoop.cpp @@ -550,7 +550,7 @@ RewritePathEvaluator::applyLeftConcreteProjection(const RewriteStep &step, const auto &difference = system.getTypeDifference(step.getTypeDifferenceID()); unsigned index = step.getSubstitutionIndex(); - MutableTerm leftProjection(difference.LHS.getSubstitutions()[index]); + auto leftProjection = difference.getOriginalSubstitution(index); MutableTerm leftBaseTerm(difference.BaseTerm); leftBaseTerm.add(difference.LHS); @@ -607,7 +607,7 @@ RewritePathEvaluator::applyRightConcreteProjection(const RewriteStep &step, const auto &difference = system.getTypeDifference(step.getTypeDifferenceID()); unsigned index = step.getSubstitutionIndex(); - MutableTerm leftProjection(difference.LHS.getSubstitutions()[index]); + auto leftProjection = difference.getOriginalSubstitution(index); auto rightProjection = difference.getReplacementSubstitution(index); MutableTerm leftBaseTerm(difference.BaseTerm); diff --git a/lib/AST/RequirementMachine/TypeDifference.cpp b/lib/AST/RequirementMachine/TypeDifference.cpp index d4e798a6d25d7..e79fcf205f4c1 100644 --- a/lib/AST/RequirementMachine/TypeDifference.cpp +++ 
b/lib/AST/RequirementMachine/TypeDifference.cpp @@ -25,6 +25,10 @@ using namespace swift; using namespace rewriting; +MutableTerm TypeDifference::getOriginalSubstitution(unsigned index) const { + return MutableTerm(LHS.getSubstitutions()[index]); +} + MutableTerm TypeDifference::getReplacementSubstitution(unsigned index) const { for (const auto &pair : SameTypes) { if (pair.first == index) { @@ -37,14 +41,14 @@ MutableTerm TypeDifference::getReplacementSubstitution(unsigned index) const { if (pair.first == index) { // Given a transformation Xn -> [concrete: D], return the // return Xn.[concrete: D]. - MutableTerm result(LHS.getSubstitutions()[index]); + auto result = getOriginalSubstitution(index); result.add(pair.second); return result; } } // Otherwise return the original substitution Xn. - return MutableTerm(LHS.getSubstitutions()[index]); + return getOriginalSubstitution(index); } void TypeDifference::dump(llvm::raw_ostream &out) const { @@ -53,12 +57,12 @@ void TypeDifference::dump(llvm::raw_ostream &out) const { llvm::errs() << "RHS: " << RHS << "\n"; for (const auto &pair : SameTypes) { - out << "- " << LHS.getSubstitutions()[pair.first] << " (#"; + out << "- " << getOriginalSubstitution(pair.first) << " (#"; out << pair.first << ") -> " << pair.second << "\n"; } for (const auto &pair : ConcreteTypes) { - out << "- " << LHS.getSubstitutions()[pair.first] << " (#"; + out << "- " << getOriginalSubstitution(pair.first) << " (#"; out << pair.first << ") -> " << pair.second << "\n"; } } diff --git a/lib/AST/RequirementMachine/TypeDifference.h b/lib/AST/RequirementMachine/TypeDifference.h index a3e2189870c7d..adba85ac1098f 100644 --- a/lib/AST/RequirementMachine/TypeDifference.h +++ b/lib/AST/RequirementMachine/TypeDifference.h @@ -58,6 +58,7 @@ struct TypeDifference { : BaseTerm(baseTerm), LHS(lhs), RHS(rhs), SameTypes(sameTypes), ConcreteTypes(concreteTypes) {} + MutableTerm getOriginalSubstitution(unsigned index) const; MutableTerm 
getReplacementSubstitution(unsigned index) const; void dump(llvm::raw_ostream &out) const; From 70876171e628c993e659a1ad455da8bcdb8c6fb8 Mon Sep 17 00:00:00 2001 From: Slava Pestov Date: Fri, 4 Feb 2022 04:29:06 -0500 Subject: [PATCH 21/25] RequirementMachine: Build rewrite paths for concrete unification induced rules --- .../PropertyUnification.cpp | 163 +++++++++++++++--- 1 file changed, 140 insertions(+), 23 deletions(-) diff --git a/lib/AST/RequirementMachine/PropertyUnification.cpp b/lib/AST/RequirementMachine/PropertyUnification.cpp index 950f893350677..4b24aa94c92b0 100644 --- a/lib/AST/RequirementMachine/PropertyUnification.cpp +++ b/lib/AST/RequirementMachine/PropertyUnification.cpp @@ -43,6 +43,15 @@ bool PropertyMap::checkRulePairOnce(unsigned firstRuleID, /// Given a key T, a rule (V.[p1] => V) where T == U.V, and a property [p2] /// where [p1] < [p2], record a rule (T.[p2] => T) that is induced by /// the original rule (V.[p1] => V). +/// +/// This is used to define rewrite loops for relating pairs of rules where +/// one implies another: +/// +/// - a more specific layout constraint implies a general layout constraint +/// - a superclass bound implies a layout constraint +/// - a concrete type that is a class implies a superclass bound +/// - a concrete type that is a class implies a layout constraint +/// static void recordRelation(Term key, unsigned lhsRuleID, Symbol rhsProperty, @@ -391,7 +400,122 @@ void PropertyMap::addSuperclassProperty( } } -/// Record induced rules from the given type difference. +/// Build a rewrite path for a rule induced by concrete type unification. +/// +/// Consider two concrete type rules (T.[LHS] => T) and (T.[RHS] => T), a +/// TypeDifference describing the transformation from LHS to RHS, and the +/// index of a substitution Xn from [LHS] which is transformed into its +/// replacement f(Xn).
+/// +/// The rewrite path should allow us to eliminate the induced rule +/// (f(Xn) => Xn), so the induced rule will appear without context, and +/// the concrete type rules (T.[LHS] => T) and (T.[RHS] => T) will appear +/// in context. +/// +/// There are two cases: +/// +/// a) The substitution Xn remains a type parameter in [RHS], but becomes +/// a canonical term Xn', so f(Xn) = Xn'. +/// +/// In the first case, the induced rule is (Xn => Xn'), described by a +/// rewrite path as follows: +/// +/// Xn +/// Xn' T.[RHS] // RightConcreteProjection(n) pushes T.[RHS] +/// Xn' T // Application of (T.[RHS] => T) in context +/// Xn' T.[LHS] // Application of (T => T.[LHS]) in context +/// Xn' // LeftConcreteProjection(n) pops T.[LHS] +/// +/// Now when this path is composed with a rewrite step for the inverted +/// induced rule (Xn' => Xn), we get a rewrite loop at Xn in which the +/// new rule appears in empty context. +/// +/// b) The substitution Xn becomes a concrete type [D] in [RHS], so +/// f(Xn) = Xn.[D]. +/// +/// In the second case, the induced rule is (Xn.[D] => Xn), described +/// by a rewrite path (going in the other direction) as follows: +/// +/// Xn +/// Xn.[D] T.[RHS] // RightConcreteProjection(n) pushes T.[RHS] +/// Xn.[D] T // Application of (T.[RHS] => T) in context +/// Xn.[D] T.[LHS] // Application of (T => T.[LHS]) in context +/// Xn.[D] // LeftConcreteProjection(n) pops T.[LHS] +/// +/// Now when this path is composed with a rewrite step for the induced +/// rule (Xn.[D] => Xn), we get a rewrite loop at Xn in which the +/// new rule appears in empty context. +/// +/// There is a minor complication; the concrete type rules T.[LHS] and +/// T.[RHS] might actually be T.[LHS] and V.[RHS] where V is a suffix of +/// T, so T = U.V for some |U| > 0, (or vice versa). In this case we need +/// an additional step in the middle to prefix the concrete substitutions +/// of [LHS] (or [RHS]) with U.
+static void buildRewritePathForInducedRule(unsigned differenceID, + unsigned lhsRuleID, + unsigned rhsRuleID, + unsigned substitutionIndex, + const RewriteSystem &system, + RewritePath &path) { + unsigned lhsLength = system.getRule(lhsRuleID).getRHS().size(); + unsigned rhsLength = system.getRule(rhsRuleID).getRHS().size(); + + unsigned lhsPrefix = 0, rhsPrefix = 0; + if (lhsLength < rhsLength) + lhsPrefix = rhsLength - lhsLength; + if (rhsLength < lhsLength) + rhsPrefix = lhsLength - rhsLength; + + assert(lhsPrefix == 0 || rhsPrefix == 0); + + // Replace f(Xn) with Xn and push T.[RHS] on the stack. + path.add(RewriteStep::forRightConcreteProjection( + differenceID, substitutionIndex, /*inverse=*/false)); + + // If the rule was actually (V.[RHS] => V) with T == U.V for some + // |U| > 0, strip U from the prefix of each substitution of [RHS]. + if (rhsPrefix > 0) { + path.add(RewriteStep::forPrefixSubstitutions(/*prefix=*/rhsPrefix, + /*endOffset=*/0, + /*inverse=*/true)); + } + + // Apply the rule (V.[RHS] => V). + path.add(RewriteStep::forRewriteRule( + /*startOffset=*/rhsPrefix, /*endOffset=*/0, + /*ruleID=*/rhsRuleID, /*inverse=*/false)); + + // Apply the inverted rule (V' => V'.[LHS]). + path.add(RewriteStep::forRewriteRule( + /*startOffset=*/lhsPrefix, /*endOffset=*/0, + /*ruleID=*/lhsRuleID, /*inverse=*/true)); + + // If the rule was actually (V.[LHS] => V) with T == U.V for some + // |U| > 0, prefix each substitution of [LHS] with U. + if (lhsPrefix > 0) { + path.add(RewriteStep::forPrefixSubstitutions(/*prefix=*/lhsPrefix, + /*endOffset=*/0, + /*inverse=*/false)); + } + + // Pop T.[LHS] from the stack, leaving behind Xn. 
+ path.add(RewriteStep::forLeftConcreteProjection( + differenceID, substitutionIndex, /*inverse=*/true)); +} + +/// Given two concrete type rules (T.[LHS] => T) and (T.[RHS] => T) and +/// TypeDifference describing the transformation from LHS to RHS, +/// record rules for transforming each substitution of LHS into a +/// more canonical type parameter or concrete type from RHS. +/// +/// This also records rewrite paths relating induced rules to the original +/// concrete type rules, since the concrete type rules imply the induced +/// rules and make them redundant. +/// +/// The implication going in the other direction -- where one of the +/// two concrete type rules together with the induced rules implies the +/// other concrete type rule -- is recorded in +/// concretelySimplifyLeftHandSideSubstitutions(). void PropertyMap::processTypeDifference(const TypeDifference &difference, unsigned differenceID, unsigned lhsRuleID, @@ -402,34 +526,22 @@ void PropertyMap::processTypeDifference(const TypeDifference &difference, difference.dump(llvm::dbgs()); } - for (const auto &pair : difference.SameTypes) { - // Both sides are type parameters; add a same-type requirement. - auto lhsTerm = difference.getOriginalSubstitution(pair.first); - MutableTerm rhsTerm(pair.second); + for (unsigned index : indices(difference.LHS.getSubstitutions())) { + auto lhsTerm = difference.getReplacementSubstitution(index); + auto rhsTerm = difference.getOriginalSubstitution(index); - if (debug) { - llvm::dbgs() << "%% Induced rule " << lhsTerm - << " == " << rhsTerm << "\n"; - } - - // FIXME: Need a rewrite path here. - System.addRule(lhsTerm, rhsTerm); - } - - for (const auto &pair : difference.ConcreteTypes) { - // A type parameter is equated with a concrete type; add a concrete - // type requirement. 
- auto rhsTerm = difference.getOriginalSubstitution(pair.first); - MutableTerm lhsTerm(rhsTerm); - lhsTerm.add(pair.second); + RewritePath path; + buildRewritePathForInducedRule(differenceID, lhsRuleID, rhsRuleID, + index, System, path); if (debug) { llvm::dbgs() << "%% Induced rule " << lhsTerm - << " == " << rhsTerm << "\n"; + << " => " << rhsTerm << " with path "; + path.dump(llvm::dbgs(), lhsTerm, System); + llvm::dbgs() << "\n"; } - // FIXME: Need a rewrite path here. - System.addRule(lhsTerm, rhsTerm); + System.addRule(lhsTerm, rhsTerm, &path); } } @@ -676,6 +788,11 @@ void PropertyMap::addProperty( llvm_unreachable("Bad symbol kind"); } +/// Post-pass to handle unification and conflict checking between pairs of +/// rules of different kinds: +/// +/// - concrete vs superclass +/// - concrete vs layout void PropertyMap::checkConcreteTypeRequirements() { bool debug = Debug.contains(DebugFlags::ConcreteUnification); From 5bb80286e9fd5376a11882a0b15e5fe50090a500 Mon Sep 17 00:00:00 2001 From: Slava Pestov Date: Sat, 5 Feb 2022 21:53:33 -0500 Subject: [PATCH 22/25] RequirementMachine: Factor out a utility for building a rewrite path unifying to concrete type rules --- .../PropertyUnification.cpp | 111 ++++++++---------- 1 file changed, 49 insertions(+), 62 deletions(-) diff --git a/lib/AST/RequirementMachine/PropertyUnification.cpp b/lib/AST/RequirementMachine/PropertyUnification.cpp index 4b24aa94c92b0..7e7d6f70eac9c 100644 --- a/lib/AST/RequirementMachine/PropertyUnification.cpp +++ b/lib/AST/RequirementMachine/PropertyUnification.cpp @@ -400,6 +400,51 @@ void PropertyMap::addSuperclassProperty( } } +/// Given two rules (V.[LHS] => V) and (V'.[RHS] => V'), build a rewrite +/// path from T.[RHS] to T.[LHS], where T is the longer of the two terms +/// V and V'. 
+static void buildRewritePathForUnifier(unsigned lhsRuleID, + unsigned rhsRuleID, + const RewriteSystem &system, + RewritePath &path) { + unsigned lhsLength = system.getRule(lhsRuleID).getRHS().size(); + unsigned rhsLength = system.getRule(rhsRuleID).getRHS().size(); + + unsigned lhsPrefix = 0, rhsPrefix = 0; + if (lhsLength < rhsLength) + lhsPrefix = rhsLength - lhsLength; + if (rhsLength < lhsLength) + rhsPrefix = lhsLength - rhsLength; + + assert(lhsPrefix == 0 || rhsPrefix == 0); + + // If the rule was actually (V.[RHS] => V) with T == U.V for some + // |U| > 0, strip U from the prefix of each substitution of [RHS]. + if (rhsPrefix > 0) { + path.add(RewriteStep::forPrefixSubstitutions(/*prefix=*/rhsPrefix, + /*endOffset=*/0, + /*inverse=*/true)); + } + + // Apply the rule (V.[RHS] => V). + path.add(RewriteStep::forRewriteRule( + /*startOffset=*/rhsPrefix, /*endOffset=*/0, + /*ruleID=*/rhsRuleID, /*inverse=*/false)); + + // Apply the inverted rule (V' => V'.[LHS]). + path.add(RewriteStep::forRewriteRule( + /*startOffset=*/lhsPrefix, /*endOffset=*/0, + /*ruleID=*/lhsRuleID, /*inverse=*/true)); + + // If the rule was actually (V.[LHS] => V) with T == U.V for some + // |U| > 0, prefix each substitution of [LHS] with U. + if (lhsPrefix > 0) { + path.add(RewriteStep::forPrefixSubstitutions(/*prefix=*/lhsPrefix, + /*endOffset=*/0, + /*inverse=*/false)); + } +} + /// Build a rewrite path for a rule induced by concrete type unification. 
/// /// Consider two concrete type rules (T.[LHS] => T) and (T.[RHS] => T), a @@ -457,46 +502,11 @@ static void buildRewritePathForInducedRule(unsigned differenceID, unsigned substitutionIndex, const RewriteSystem &system, RewritePath &path) { - unsigned lhsLength = system.getRule(lhsRuleID).getRHS().size(); - unsigned rhsLength = system.getRule(rhsRuleID).getRHS().size(); - - unsigned lhsPrefix = 0, rhsPrefix = 0; - if (lhsLength < rhsLength) - lhsPrefix = rhsLength - lhsLength; - if (rhsLength < lhsLength) - rhsPrefix = lhsLength - rhsLength; - - assert(lhsPrefix == 0 || rhsPrefix == 0); - // Replace f(Xn) with Xn and push T.[RHS] on the stack. path.add(RewriteStep::forRightConcreteProjection( differenceID, substitutionIndex, /*inverse=*/false)); - // If the rule was actually (V.[RHS] => V) with T == U.V for some - // |U| > 0, strip U from the prefix of each substitution of [RHS]. - if (rhsPrefix > 0) { - path.add(RewriteStep::forPrefixSubstitutions(/*prefix=*/rhsPrefix, - /*endOffset=*/0, - /*inverse=*/true)); - } - - // Apply the rule (V.[RHS] => V). - path.add(RewriteStep::forRewriteRule( - /*startOffset=*/rhsPrefix, /*endOffset=*/0, - /*ruleID=*/rhsRuleID, /*inverse=*/false)); - - // Apply the inverted rule (V' => V'.[LHS]). - path.add(RewriteStep::forRewriteRule( - /*startOffset=*/lhsPrefix, /*endOffset=*/0, - /*ruleID=*/lhsRuleID, /*inverse=*/true)); - - // If the rule was actually (V.[LHS] => V) with T == U.V for some - // |U| > 0, prefix each substitution of [LHS] with U. - if (lhsPrefix > 0) { - path.add(RewriteStep::forPrefixSubstitutions(/*prefix=*/lhsPrefix, - /*endOffset=*/0, - /*inverse=*/false)); - } + buildRewritePathForUnifier(lhsRuleID, rhsRuleID, system, path); // Pop T.[LHS] from the stack, leaving behind Xn. path.add(RewriteStep::forLeftConcreteProjection( @@ -718,33 +728,10 @@ void PropertyMap::addConcreteTypeProperty( // // Since the new rule appears without context, it becomes redundant. 
if (checkRulePairOnce(*props->ConcreteTypeRule, ruleID)) { - const auto &otherRule = System.getRule(*props->ConcreteTypeRule); - assert(otherRule.getRHS().size() < key.size()); - - unsigned prefixLength = (key.size() - otherRule.getRHS().size()); - - // Build a loop that rewrites U.V back into itself via the two rules, - // with a prefix substitutions step in the middle. RewritePath path; - - // Add a rewrite step U.(V => V.[concrete: G<...> with ]). - path.add(RewriteStep::forRewriteRule(/*startOffset=*/prefixLength, - /*endOffset=*/0, - *props->ConcreteTypeRule, - /*inverse=*/true)); - - // Add a rewrite step to prefix 'U' to the substitutions. - path.add(RewriteStep::forPrefixSubstitutions(/*length=*/prefixLength, - /*endOffset=*/0, - /*inverse=*/false)); - - // Add a rewrite step (U.V.[concrete: G<...> with ] => U.V). - path.add(RewriteStep::forRewriteRule(/*startOffset=*/0, - /*endOffset=*/0, - ruleID, - /*inverse=*/false)); - - System.recordRewriteLoop(MutableTerm(key), path); + buildRewritePathForUnifier(*props->ConcreteTypeRule, ruleID, System, + path); + System.recordRewriteLoop(MutableTerm(rule.getLHS()), path); } } } From 9e234a09c6dfc6c8bd7eb0a3375e4b362346abf9 Mon Sep 17 00:00:00 2001 From: Slava Pestov Date: Sat, 5 Feb 2022 22:57:17 -0500 Subject: [PATCH 23/25] RequirementMachine: Record rewrite loop relating concrete type rules in processTypeDifference() We can't actually rely on concretelySimplifyLeftHandSideSubstitutions() to do this for us, because the less-simplified rule (the LHS rule) might only apply to a suffix of the base term. 
--- .../PropertyUnification.cpp | 68 ++++++++++++++++--- test/Generics/unify_concrete_types_1.swift | 1 - 2 files changed, 59 insertions(+), 10 deletions(-) diff --git a/lib/AST/RequirementMachine/PropertyUnification.cpp b/lib/AST/RequirementMachine/PropertyUnification.cpp index 7e7d6f70eac9c..982c1727871e5 100644 --- a/lib/AST/RequirementMachine/PropertyUnification.cpp +++ b/lib/AST/RequirementMachine/PropertyUnification.cpp @@ -522,10 +522,8 @@ static void buildRewritePathForInducedRule(unsigned differenceID, /// concrete type rules, since the concrete type rules imply the induced /// rules and make them redundant. /// -/// The implication going in the other direction -- where one of the -/// two concrete type rules together with the induced rules implies the -/// other concrete type rule -- is recorded in -/// concretelySimplifyLeftHandSideSubstitutions(). +/// Finally, builds a rewrite loop relating the two concrete type rules +/// via the induced rules. void PropertyMap::processTypeDifference(const TypeDifference &difference, unsigned differenceID, unsigned lhsRuleID, @@ -536,23 +534,75 @@ void PropertyMap::processTypeDifference(const TypeDifference &difference, difference.dump(llvm::dbgs()); } - for (unsigned index : indices(difference.LHS.getSubstitutions())) { + RewritePath unificationPath; + + auto substitutions = difference.LHS.getSubstitutions(); + + // The term is at the top of the primary stack. Push all substitutions onto + // the primary stack. + unificationPath.add(RewriteStep::forDecompose(substitutions.size(), + /*inverse=*/false)); + + // Move all substitutions but the first one to the secondary stack. + for (unsigned i = 1; i < substitutions.size(); ++i) + unificationPath.add(RewriteStep::forShift(/*inverse=*/false)); + + for (unsigned index : indices(substitutions)) { + // Move the next substitution from the secondary stack to the primary stack. 
+ if (index != 0) + unificationPath.add(RewriteStep::forShift(/*inverse=*/true)); + auto lhsTerm = difference.getReplacementSubstitution(index); auto rhsTerm = difference.getOriginalSubstitution(index); - RewritePath path; + RewritePath inducedRulePath; buildRewritePathForInducedRule(differenceID, lhsRuleID, rhsRuleID, - index, System, path); + index, System, inducedRulePath); if (debug) { llvm::dbgs() << "%% Induced rule " << lhsTerm << " => " << rhsTerm << " with path "; - path.dump(llvm::dbgs(), lhsTerm, System); + inducedRulePath.dump(llvm::dbgs(), lhsTerm, System); llvm::dbgs() << "\n"; } - System.addRule(lhsTerm, rhsTerm, &path); + System.addRule(lhsTerm, rhsTerm, &inducedRulePath); + + // Build a path from rhsTerm (the original substitution) to + // lhsTerm (the replacement substitution). + MutableTerm mutRhsTerm(rhsTerm); + (void) System.simplify(mutRhsTerm, &unificationPath); + + RewritePath lhsPath; + MutableTerm mutLhsTerm(lhsTerm); + (void) System.simplify(mutLhsTerm, &lhsPath); + + assert(mutLhsTerm == mutRhsTerm && "Terms should be joinable"); + lhsPath.invert(); + unificationPath.append(lhsPath); } + + // All simplified substitutions are now on the primary stack. Collect them to + // produce the new term. + unificationPath.add(RewriteStep::forDecomposeConcrete(differenceID, + /*inverse=*/true)); + + // We now have a unification path from T.[RHS] to T.[LHS] using the + // newly-recorded induced rules. Close the loop with a path from + // T.[RHS] to T.[LHS] via the concrete type rules being unified. + buildRewritePathForUnifier(lhsRuleID, rhsRuleID, System, unificationPath); + + // Record a rewrite loop at T.[LHS]. 
+ MutableTerm basepoint(difference.BaseTerm); + basepoint.add(difference.LHS); + System.recordRewriteLoop(basepoint, unificationPath); + + // Optimization: If the LHS rule applies to the entire base term and not + // a suffix, mark it substitution-simplified so that we can skip recording + // the same rewrite loop in concretelySimplifyLeftHandSideSubstitutions(). + auto &lhsRule = System.getRule(lhsRuleID); + if (lhsRule.getRHS() == difference.BaseTerm) + lhsRule.markSubstitutionSimplified(); } /// When a type parameter has two concrete types, we have to unify the diff --git a/test/Generics/unify_concrete_types_1.swift b/test/Generics/unify_concrete_types_1.swift index f6783710aff12..17a70f40d3c75 100644 --- a/test/Generics/unify_concrete_types_1.swift +++ b/test/Generics/unify_concrete_types_1.swift @@ -28,5 +28,4 @@ struct MergeTest { // CHECK: [P1:X] => { concrete_type: [concrete: Foo<τ_0_0, τ_0_1> with <[P1:Y1], [P1:Z1]>] } // CHECK: [P2:X] => { concrete_type: [concrete: Foo<τ_0_0, τ_0_1> with <[P2:Y2], [P2:Z2]>] } // CHECK: τ_0_0 => { conforms_to: [P1 P2] } -// CHECK: τ_0_0.[P1:X] => { concrete_type: [concrete: Foo<τ_0_0, τ_0_1> with <τ_0_0.[P1:Y1], τ_0_0.[P1:Z1]>] } // CHECK: } From 868e48cb7a35dbd43cfbfaeb6c2761c089f74548 Mon Sep 17 00:00:00 2001 From: Slava Pestov Date: Sun, 6 Feb 2022 00:37:23 -0500 Subject: [PATCH 24/25] RequirementMachine: Mark rules as simplified in PropertyMap::addConcreteTypeProperty() --- lib/AST/RequirementMachine/PropertyUnification.cpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/lib/AST/RequirementMachine/PropertyUnification.cpp b/lib/AST/RequirementMachine/PropertyUnification.cpp index 982c1727871e5..2d86be9ce0d24 100644 --- a/lib/AST/RequirementMachine/PropertyUnification.cpp +++ b/lib/AST/RequirementMachine/PropertyUnification.cpp @@ -629,7 +629,7 @@ void PropertyMap::addConcreteTypeProperty( Term key, Symbol property, unsigned ruleID) { auto *props = getOrCreateProperties(key); - const auto &rule = 
System.getRule(ruleID); + auto &rule = System.getRule(ruleID); assert(rule.getRHS() == key); bool debug = Debug.contains(DebugFlags::ConcreteUnification); @@ -782,6 +782,8 @@ void PropertyMap::addConcreteTypeProperty( buildRewritePathForUnifier(*props->ConcreteTypeRule, ruleID, System, path); System.recordRewriteLoop(MutableTerm(rule.getLHS()), path); + + rule.markSubstitutionSimplified(); } } } From 729dfc7c799ada08329179ccb2daa6785c5fff27 Mon Sep 17 00:00:00 2001 From: Slava Pestov Date: Sun, 6 Feb 2022 01:10:11 -0500 Subject: [PATCH 25/25] RequirementMachine: Add some tests for concrete type unification --- .../minimize_concrete_unification.swift | 173 ++++++++++++++++++ 1 file changed, 173 insertions(+) create mode 100644 test/Generics/minimize_concrete_unification.swift diff --git a/test/Generics/minimize_concrete_unification.swift b/test/Generics/minimize_concrete_unification.swift new file mode 100644 index 0000000000000..4c53f044c2e34 --- /dev/null +++ b/test/Generics/minimize_concrete_unification.swift @@ -0,0 +1,173 @@ +// RUN: %target-swift-frontend -typecheck %s -debug-generic-signatures -requirement-machine-protocol-signatures=on 2>&1 | %FileCheck %s + +struct G {} + +// Three requirements where any two imply the third: +// +// a) T == G +// b) T == G +// c) U == Int + +// CHECK-LABEL: minimize_concrete_unification.(file).Pab@ +// CHECK-NEXT: Requirement signature: , Self.[Pab]U == Int> + +protocol Pab { + associatedtype T where T == G, T == G + associatedtype U +} + +// CHECK-LABEL: minimize_concrete_unification.(file).Pac@ +// CHECK-NEXT: Requirement signature: , Self.[Pac]U == Int> + +protocol Pac { + associatedtype T where T == G + associatedtype U where U == Int +} + +// CHECK-LABEL: minimize_concrete_unification.(file).Pbc@ +// CHECK-NEXT: Requirement signature: , Self.[Pbc]U == Int> + +protocol Pbc { + associatedtype T where T == G + associatedtype U where U == Int +} + +// CHECK-LABEL: minimize_concrete_unification.(file).Pabc@ +// CHECK-NEXT: 
Requirement signature: , Self.[Pabc]U == Int> + +protocol Pabc { + associatedtype T where T == G, T == G + associatedtype U where U == Int +} + +// + +// CHECK-LABEL: minimize_concrete_unification.(file).Pa@ +// CHECK-NEXT: Requirement signature: > + +protocol Pa { + associatedtype T where T == G + associatedtype U +} + +// CHECK-LABEL: minimize_concrete_unification.(file).PaQb@ +// CHECK-NEXT: Requirement signature: + +protocol PaQb { + associatedtype X : Pa where X.T == G +} + +// CHECK-LABEL: minimize_concrete_unification.(file).PaQc@ +// CHECK-NEXT: Requirement signature: + +protocol PaQc { + associatedtype X : Pa where X.U == Int +} + +// + +// CHECK-LABEL: minimize_concrete_unification.(file).Pb@ +// CHECK-NEXT: Requirement signature: > + +protocol Pb { + associatedtype T where T == G + associatedtype U +} + +// CHECK-LABEL: minimize_concrete_unification.(file).PbQa@ +// CHECK-NEXT: Requirement signature: + +protocol PbQa { + associatedtype X : Pb where X.T == G +} + +// CHECK-LABEL: minimize_concrete_unification.(file).PbQc@ +// CHECK-NEXT: Requirement signature: + +protocol PbQc { + associatedtype X : Pb where X.U == Int +} + +// + +// CHECK-LABEL: minimize_concrete_unification.(file).Pc@ +// CHECK-NEXT: Requirement signature: + +protocol Pc { + associatedtype T + associatedtype U where U == Int +} + +// CHECK-LABEL: minimize_concrete_unification.(file).PcQa@ +// CHECK-NEXT: Requirement signature: > + +protocol PcQa { + associatedtype X : Pc where X.T == G +} + +// CHECK-LABEL: minimize_concrete_unification.(file).PcQb@ +// CHECK-NEXT: Requirement signature: > + +protocol PcQb { + associatedtype X : Pc where X.T == G +} + +// + +// CHECK-LABEL: minimize_concrete_unification.(file).Q1@ +// CHECK-NEXT: Requirement signature: + +protocol Q1 { + associatedtype V where V : Pa, V.T == G + associatedtype W +} + +// + +// CHECK-LABEL: minimize_concrete_unification.(file).P1@ +// CHECK-NEXT: Requirement signature: > + +protocol P1 { + associatedtype T + 
associatedtype U where T == G +} + +// CHECK-LABEL: minimize_concrete_unification.(file).P2@ +// CHECK-NEXT: Requirement signature: > + +protocol P2 { + associatedtype T where T == G + associatedtype U +} + +// CHECK-LABEL: minimize_concrete_unification.(file).P@ +// CHECK-NEXT: Requirement signature: + +protocol P { + associatedtype T + associatedtype U +} + +// CHECK-LABEL: minimize_concrete_unification.(file).R1@ +// CHECK-NEXT: Requirement signature: + +protocol R1 { + // The GSB would drop 'X.T == Int' from the minimal signature. + associatedtype X where X : P, X.T == G, X : Pa +} + +// CHECK-LABEL: minimize_concrete_unification.(file).R2@ +// CHECK-NEXT: Requirement signature: + +protocol R2 { + // The GSB would drop 'X.T == Int' from the minimal signature. + associatedtype X where X : P, X.T == G, X : Pb +} + +// CHECK-LABEL: minimize_concrete_unification.(file).R3@ +// CHECK-NEXT: Requirement signature: + +protocol R3 { + // The GSB would include a redundant 'X.T == Int' in the minimal signature. + associatedtype X where X : Pa, X.T == G, X : Pb +}