diff --git a/include/swift/Basic/LangOptions.h b/include/swift/Basic/LangOptions.h index 999bc5d83a1ce..f7ec4f155383a 100644 --- a/include/swift/Basic/LangOptions.h +++ b/include/swift/Basic/LangOptions.h @@ -913,13 +913,6 @@ namespace swift { /// is for testing purposes. std::vector DebugForbidTypecheckPrefixes; - /// The upper bound to number of sub-expressions unsolved - /// before termination of the shrink phrase of the constraint solver. - unsigned SolverShrinkUnsolvedThreshold = 10; - - /// Disable the shrink phase of the expression type checker. - bool SolverDisableShrink = false; - /// Enable experimental operator designated types feature. bool EnableOperatorDesignatedTypes = false; @@ -935,6 +928,9 @@ namespace swift { /// Allow request evalutation to perform type checking lazily, instead of /// eagerly typechecking source files after parsing. bool EnableLazyTypecheck = false; + + /// Disable the component splitter phase of the expression type checker. + bool SolverDisableSplitter = false; }; /// Options for controlling the behavior of the Clang importer. diff --git a/include/swift/Option/FrontendOptions.td b/include/swift/Option/FrontendOptions.td index 6e9852d2d1e33..d52792fd697ad 100644 --- a/include/swift/Option/FrontendOptions.td +++ b/include/swift/Option/FrontendOptions.td @@ -825,15 +825,17 @@ def downgrade_typecheck_interface_error : Flag<["-"], "downgrade-typecheck-inter def enable_volatile_modules : Flag<["-"], "enable-volatile-modules">, HelpText<"Load Swift modules in memory">; -def solver_expression_time_threshold_EQ : Joined<["-"], "solver-expression-time-threshold=">; +def solver_expression_time_threshold_EQ : Joined<["-"], "solver-expression-time-threshold=">, + HelpText<"Expression type checking timeout, in seconds">; -def solver_scope_threshold_EQ : Joined<["-"], "solver-scope-threshold=">; +def solver_scope_threshold_EQ : Joined<["-"], "solver-scope-threshold=">, + HelpText<"Expression type checking scope limit">; -def solver_trail_threshold_EQ : Joined<["-"], "solver-trail-threshold=">; +def solver_trail_threshold_EQ : Joined<["-"], "solver-trail-threshold=">, + HelpText<"Expression type checking trail change limit">; -def solver_disable_shrink : - Flag<["-"], "solver-disable-shrink">, - HelpText<"Disable the shrink phase of expression type checking">; +def solver_disable_splitter : Flag<["-"], "solver-disable-splitter">, + HelpText<"Disable the component splitter phase of expression type checking">; def disable_constraint_solver_performance_hacks : Flag<["-"], "disable-constraint-solver-performance-hacks">, HelpText<"Disable all the hacks in the constraint solver">; diff --git a/include/swift/Sema/CSBindings.h b/include/swift/Sema/CSBindings.h index b06079bb2ae03..cb546d27b35cb 100644 --- a/include/swift/Sema/CSBindings.h +++ b/include/swift/Sema/CSBindings.h @@ -301,6 +301,11 @@ struct PotentialBindings { Constraint *constraint); void reset(); + + void dump(ConstraintSystem &CS, + TypeVariableType *TypeVar, + llvm::raw_ostream &out, + unsigned indent) const; }; @@ -567,64 +572,27 @@ class BindingSet { /// /// \param inferredBindings The set of all bindings inferred for type /// variables in the workset. - void inferTransitiveBindings( - const llvm::SmallDenseMap - &inferredBindings); + void inferTransitiveBindings(); /// Detect subtype, conversion or equivalence relationship /// between two type variables and attempt to propagate protocol /// requirements down the subtype or equivalence chain. - void inferTransitiveProtocolRequirements( - llvm::SmallDenseMap &inferredBindings); + void inferTransitiveProtocolRequirements(); /// Finalize binding computation for this type variable by /// inferring bindings from context e.g. transitive bindings. /// /// \returns true if finalization successful (which makes binding set viable), /// and false otherwise. - bool finalize( - llvm::SmallDenseMap &inferredBindings); + bool finalize(bool transitive); static BindingScore formBindingScore(const BindingSet &b); - /// Compare two sets of bindings, where \c x < y indicates that - /// \c x is a better set of bindings that \c y. - friend bool operator<(const BindingSet &x, const BindingSet &y) { - auto xScore = formBindingScore(x); - auto yScore = formBindingScore(y); - - if (xScore < yScore) - return true; - - if (yScore < xScore) - return false; - - auto xDefaults = x.getNumViableDefaultableBindings(); - auto yDefaults = y.getNumViableDefaultableBindings(); - - // If there is a difference in number of default types, - // prioritize bindings with fewer of them. - if (xDefaults != yDefaults) - return xDefaults < yDefaults; - - // If neither type variable is a "hole" let's check whether - // there is a subtype relationship between them and prefer - // type variable which represents superclass first in order - // for "subtype" type variable to attempt more bindings later. - // This is required because algorithm can't currently infer - // bindings for subtype transitively through superclass ones. - if (!(std::get<0>(xScore) && std::get<0>(yScore))) { - if (x.Info.isSubtypeOf(y.getTypeVariable())) - return false; - - if (y.Info.isSubtypeOf(x.getTypeVariable())) - return true; - } + bool operator==(const BindingSet &other); - // As a last resort, let's check if the bindings are - // potentially incomplete, and if so, let's de-prioritize them. - return x.isPotentiallyIncomplete() < y.isPotentiallyIncomplete(); - } + /// Compare two sets of bindings, where \c this < other indicates that + /// \c this is a better set of bindings that \c other. + bool operator<(const BindingSet &other); void dump(llvm::raw_ostream &out, unsigned indent) const; diff --git a/include/swift/Sema/ConstraintGraph.h b/include/swift/Sema/ConstraintGraph.h index 5587c1c2de3c0..d68794d75b515 100644 --- a/include/swift/Sema/ConstraintGraph.h +++ b/include/swift/Sema/ConstraintGraph.h @@ -84,9 +84,24 @@ class ConstraintGraphNode { /// as this type variable. ArrayRef getEquivalenceClass() const; - inference::PotentialBindings &getCurrentBindings() { - assert(forRepresentativeVar()); - return Bindings; + inference::PotentialBindings &getPotentialBindings() { + DEBUG_ASSERT(forRepresentativeVar()); + return Potential; + } + + void initBindingSet(); + + inference::BindingSet &getBindingSet() { + ASSERT(hasBindingSet()); + return *Set; + } + + bool hasBindingSet() const { + return Set.has_value(); + } + + void resetBindingSet() { + Set.reset(); } private: @@ -131,15 +146,6 @@ class ConstraintGraphNode { /// Binding Inference { - /// Infer bindings from the given constraint and notify referenced variables - /// about its arrival (if requested). This happens every time a new constraint - /// gets added to a constraint graph node. - void introduceToInference(Constraint *constraint); - - /// Forget about the given constraint. This happens every time a constraint - /// gets removed for a constraint graph. - void retractFromInference(Constraint *constraint); - /// Perform graph updates that must be undone after we bind a fixed type /// to a type variable. void retractFromInference(Type fixedType); @@ -182,8 +188,13 @@ class ConstraintGraphNode { /// The type variable this node represents. TypeVariableType *TypeVar; - /// The set of bindings associated with this type variable. - inference::PotentialBindings Bindings; + /// The potential bindings for this type variable, updated incrementally by + /// the constraint graph. + inference::PotentialBindings Potential; + + /// The binding set for this type variable, computed by + /// determineBestBindings(). + std::optional Set; /// The vector of constraints that mention this type variable, in a stable /// order for iteration. diff --git a/include/swift/Sema/ConstraintSystem.h b/include/swift/Sema/ConstraintSystem.h index 2fa61caf9b2dc..d74420a7707a5 100644 --- a/include/swift/Sema/ConstraintSystem.h +++ b/include/swift/Sema/ConstraintSystem.h @@ -5161,7 +5161,9 @@ class ConstraintSystem { /// Get bindings for the given type variable based on current /// state of the constraint system. - BindingSet getBindingsFor(TypeVariableType *typeVar, bool finalize = true); + /// + /// FIXME: Remove this. + BindingSet getBindingsFor(TypeVariableType *typeVar); private: /// Add a constraint to the constraint system. diff --git a/lib/Frontend/CompilerInvocation.cpp b/lib/Frontend/CompilerInvocation.cpp index 483c7566ff8e5..1626b6e407bb9 100644 --- a/lib/Frontend/CompilerInvocation.cpp +++ b/lib/Frontend/CompilerInvocation.cpp @@ -1775,8 +1775,6 @@ static bool ParseTypeCheckerArgs(TypeCheckerOptions &Opts, ArgList &Args, Opts.SolverScopeThreshold); setUnsignedIntegerArgument(OPT_solver_trail_threshold_EQ, Opts.SolverTrailThreshold); - setUnsignedIntegerArgument(OPT_solver_shrink_unsolved_threshold, - Opts.SolverShrinkUnsolvedThreshold); Opts.DebugTimeFunctionBodies |= Args.hasArg(OPT_debug_time_function_bodies); Opts.DebugTimeExpressions |= @@ -1865,8 +1863,8 @@ static bool ParseTypeCheckerArgs(TypeCheckerOptions &Opts, ArgList &Args, Opts.DebugForbidTypecheckPrefixes.push_back(A); } - if (Args.getLastArg(OPT_solver_disable_shrink)) - Opts.SolverDisableShrink = true; + if (Args.getLastArg(OPT_solver_disable_splitter)) + Opts.SolverDisableSplitter = true; if (FrontendOpts.RequestedAction == FrontendOptions::ActionType::Immediate) Opts.DeferToRuntime = true; diff --git a/lib/Sema/CSBindings.cpp b/lib/Sema/CSBindings.cpp index 6c945baf65509..2c76ea0d65a0b 100644 --- a/lib/Sema/CSBindings.cpp +++ b/lib/Sema/CSBindings.cpp @@ -31,6 +31,13 @@ using namespace swift; using namespace constraints; using namespace inference; +void ConstraintGraphNode::initBindingSet() { + ASSERT(!hasBindingSet()); + ASSERT(forRepresentativeVar()); + + Set.emplace(CG.getConstraintSystem(), TypeVar, Potential); +} + /// Check whether there exists a type that could be implicitly converted /// to a given type i.e. is the given type is Double or Optional<..> this /// function is going to return true because CGFloat could be converted @@ -278,8 +285,7 @@ bool BindingSet::isPotentiallyIncomplete() const { return false; } -void BindingSet::inferTransitiveProtocolRequirements( - llvm::SmallDenseMap &inferredBindings) { +void BindingSet::inferTransitiveProtocolRequirements() { if (TransitiveProtocols) return; @@ -314,13 +320,13 @@ void BindingSet::inferTransitiveProtocolRequirements( do { auto *currentVar = workList.back().second; - auto cachedBindings = inferredBindings.find(currentVar); - if (cachedBindings == inferredBindings.end()) { + auto &node = CS.getConstraintGraph()[currentVar]; + if (!node.hasBindingSet()) { workList.pop_back(); continue; } - auto &bindings = cachedBindings->getSecond(); + auto &bindings = node.getBindingSet(); // If current variable already has transitive protocol // conformances inferred, there is no need to look deeper @@ -352,11 +358,11 @@ void BindingSet::inferTransitiveProtocolRequirements( if (!equivalenceClass.insert(typeVar)) continue; - auto bindingSet = inferredBindings.find(typeVar); - if (bindingSet == inferredBindings.end()) + auto &node = CS.getConstraintGraph()[typeVar]; + if (!node.hasBindingSet()) continue; - auto &equivalences = bindingSet->getSecond().Info.EquivalentTo; + auto &equivalences = node.getBindingSet().Info.EquivalentTo; for (const auto &eqVar : equivalences) { workList.push_back(eqVar.first); } @@ -367,11 +373,11 @@ void BindingSet::inferTransitiveProtocolRequirements( if (memberVar == currentVar) continue; - auto eqBindings = inferredBindings.find(memberVar); - if (eqBindings == inferredBindings.end()) + auto &node = CS.getConstraintGraph()[memberVar]; + if (!node.hasBindingSet()) continue; - const auto &bindings = eqBindings->getSecond(); + const auto &bindings = node.getBindingSet(); llvm::SmallPtrSet placeholder; // Add any direct protocols from members of the @@ -423,9 +429,9 @@ void BindingSet::inferTransitiveProtocolRequirements( // Propagate inferred protocols to all of the members of the // equivalence class. for (const auto &equivalence : bindings.Info.EquivalentTo) { - auto eqBindings = inferredBindings.find(equivalence.first); - if (eqBindings != inferredBindings.end()) { - auto &bindings = eqBindings->getSecond(); + auto &node = CS.getConstraintGraph()[equivalence.first]; + if (node.hasBindingSet()) { + auto &bindings = node.getBindingSet(); bindings.TransitiveProtocols.emplace(protocolsForEquivalence.begin(), protocolsForEquivalence.end()); } @@ -438,9 +444,7 @@ void BindingSet::inferTransitiveProtocolRequirements( } while (!workList.empty()); } -void BindingSet::inferTransitiveBindings( - const llvm::SmallDenseMap - &inferredBindings) { +void BindingSet::inferTransitiveBindings() { using BindingKind = AllowedBindingKind; // If the current type variable represents a key path root type @@ -450,9 +454,9 @@ void BindingSet::inferTransitiveBindings( auto *locator = TypeVar->getImpl().getLocator(); if (auto *keyPathTy = CS.getType(locator->getAnchor())->getAs()) { - auto keyPathBindings = inferredBindings.find(keyPathTy); - if (keyPathBindings != inferredBindings.end()) { - auto &bindings = keyPathBindings->getSecond(); + auto &node = CS.getConstraintGraph()[keyPathTy]; + if (node.hasBindingSet()) { + auto &bindings = node.getBindingSet(); for (auto &binding : bindings.Bindings) { auto bindingTy = binding.BindingType->lookThroughAllOptionalTypes(); @@ -476,9 +480,9 @@ void BindingSet::inferTransitiveBindings( // transitively used because conversions between generic arguments // are not allowed. if (auto *contextualRootVar = inferredRootTy->getAs()) { - auto rootBindings = inferredBindings.find(contextualRootVar); - if (rootBindings != inferredBindings.end()) { - auto &bindings = rootBindings->getSecond(); + auto &node = CS.getConstraintGraph()[contextualRootVar]; + if (node.hasBindingSet()) { + auto &bindings = node.getBindingSet(); // Don't infer if root is not yet fully resolved. if (bindings.isDelayed()) @@ -507,11 +511,11 @@ void BindingSet::inferTransitiveBindings( } for (const auto &entry : Info.SupertypeOf) { - auto relatedBindings = inferredBindings.find(entry.first); - if (relatedBindings == inferredBindings.end()) + auto &node = CS.getConstraintGraph()[entry.first]; + if (!node.hasBindingSet()) continue; - auto &bindings = relatedBindings->getSecond(); + auto &bindings = node.getBindingSet(); // FIXME: This is a workaround necessary because solver doesn't filter // bindings based on protocol requirements placed on a type variable. @@ -610,9 +614,9 @@ static Type getKeyPathType(ASTContext &ctx, KeyPathCapability capability, return keyPathTy; } -bool BindingSet::finalize( - llvm::SmallDenseMap &inferredBindings) { - inferTransitiveBindings(inferredBindings); +bool BindingSet::finalize(bool transitive) { + if (transitive) + inferTransitiveBindings(); determineLiteralCoverage(); @@ -628,8 +632,8 @@ bool BindingSet::finalize( // func foo(_: T) {} // foo(.bar) <- `.bar` should be a static member of `P`. // \endcode - if (!hasViableBindings()) { - inferTransitiveProtocolRequirements(inferredBindings); + if (transitive && !hasViableBindings()) { + inferTransitiveProtocolRequirements(); if (TransitiveProtocols.has_value()) { for (auto *constraint : *TransitiveProtocols) { @@ -956,6 +960,56 @@ void BindingSet::addLiteralRequirement(Constraint *constraint) { Literals.insert({protocol, std::move(literal)}); } +bool BindingSet::operator==(const BindingSet &other) { + if (AdjacentVars != other.AdjacentVars) + return false; + + if (Bindings.size() != other.Bindings.size()) + return false; + + for (auto i : indices(Bindings)) { + const auto &x = Bindings[i]; + const auto &y = other.Bindings[i]; + + if (x.BindingType.getPointer() != y.BindingType.getPointer() || + x.Kind != y.Kind) + return false; + } + + if (Literals.size() != other.Literals.size()) + return false; + + for (auto pair : Literals) { + auto found = other.Literals.find(pair.first); + if (found == other.Literals.end()) + return false; + + const auto &x = pair.second; + const auto &y = found->second; + + if (x.Source != y.Source || + x.DefaultType.getPointer() != y.DefaultType.getPointer() || + x.IsDirectRequirement != y.IsDirectRequirement) { + return false; + } + } + + if (Defaults.size() != other.Defaults.size()) + return false; + + for (auto pair : Defaults) { + auto found = other.Defaults.find(pair.first); + if (found == other.Defaults.end() || + pair.second != found->second) + return false; + } + + if (TransitiveProtocols != other.TransitiveProtocols) + return false; + + return true; +} + BindingSet::BindingScore BindingSet::formBindingScore(const BindingSet &b) { // If there are no bindings available but this type // variable represents a closure - let's consider it @@ -976,17 +1030,54 @@ BindingSet::BindingScore BindingSet::formBindingScore(const BindingSet &b) { -numNonDefaultableBindings); } +bool BindingSet::operator<(const BindingSet &other) { + auto xScore = formBindingScore(*this); + auto yScore = formBindingScore(other); + + if (xScore < yScore) + return true; + + if (yScore < xScore) + return false; + + auto xDefaults = getNumViableDefaultableBindings(); + auto yDefaults = other.getNumViableDefaultableBindings(); + + // If there is a difference in number of default types, + // prioritize bindings with fewer of them. + if (xDefaults != yDefaults) + return xDefaults < yDefaults; + + // If neither type variable is a "hole" let's check whether + // there is a subtype relationship between them and prefer + // type variable which represents superclass first in order + // for "subtype" type variable to attempt more bindings later. + // This is required because algorithm can't currently infer + // bindings for subtype transitively through superclass ones. + if (!(std::get<0>(xScore) && std::get<0>(yScore))) { + if (Info.isSubtypeOf(other.getTypeVariable())) + return false; + + if (other.Info.isSubtypeOf(getTypeVariable())) + return true; + } + + // As a last resort, let's check if the bindings are + // potentially incomplete, and if so, let's de-prioritize them. + return isPotentiallyIncomplete() < other.isPotentiallyIncomplete(); +} + std::optional ConstraintSystem::determineBestBindings( llvm::function_ref onCandidate) { // Look for potential type variable bindings. - std::optional bestBindings; - llvm::SmallDenseMap cache; + BindingSet *bestBindings = nullptr; // First, let's collect all of the possible bindings. for (auto *typeVar : getTypeVariables()) { - if (!typeVar->getImpl().hasRepresentativeOrFixed()) { - cache.insert({typeVar, getBindingsFor(typeVar, /*finalize=*/false)}); - } + auto &node = CG[typeVar]; + node.resetBindingSet(); + if (!typeVar->getImpl().hasRepresentativeOrFixed()) + node.initBindingSet(); } // Determine whether given type variable with its set of bindings is @@ -1023,11 +1114,12 @@ std::optional ConstraintSystem::determineBestBindings( // Now let's see if we could infer something for related type // variables based on other bindings. for (auto *typeVar : getTypeVariables()) { - auto cachedBindings = cache.find(typeVar); - if (cachedBindings == cache.end()) + auto &node = CG[typeVar]; + if (!node.hasBindingSet()) continue; - auto &bindings = cachedBindings->getSecond(); + auto &bindings = node.getBindingSet(); + // Before attempting to infer transitive bindings let's check // whether there are any viable "direct" bindings associated with // current type variable, if there are none - it means that this type @@ -1040,7 +1132,7 @@ std::optional ConstraintSystem::determineBestBindings( // produce a default type. bool isViable = isViableForRanking(bindings); - if (!bindings.finalize(cache)) + if (!bindings.finalize(true)) continue; if (!bindings || !isViable) @@ -1051,10 +1143,13 @@ std::optional ConstraintSystem::determineBestBindings( // If these are the first bindings, or they are better than what // we saw before, use them instead. if (!bestBindings || bindings < *bestBindings) - bestBindings.emplace(bindings); + bestBindings = &bindings; } - return bestBindings; + if (!bestBindings) + return std::nullopt; + + return std::optional(*bestBindings); } /// Find the set of type variables that are inferable from the given type. @@ -1435,18 +1530,13 @@ bool BindingSet::favoredOverConjunction(Constraint *conjunction) const { return true; } -BindingSet ConstraintSystem::getBindingsFor(TypeVariableType *typeVar, - bool finalize) { +BindingSet ConstraintSystem::getBindingsFor(TypeVariableType *typeVar) { assert(typeVar->getImpl().getRepresentative(nullptr) == typeVar && "not a representative"); assert(!typeVar->getImpl().getFixedType(nullptr) && "has a fixed type"); - BindingSet bindings(*this, typeVar, CG[typeVar].getCurrentBindings()); - - if (finalize) { - llvm::SmallDenseMap cache; - bindings.finalize(cache); - } + BindingSet bindings(*this, typeVar, CG[typeVar].getPotentialBindings()); + bindings.finalize(false); return bindings; } @@ -2042,6 +2132,48 @@ void PotentialBindings::reset() { AssociatedCodeCompletionToken = ASTNode(); } +void PotentialBindings::dump(ConstraintSystem &cs, + TypeVariableType *typeVar, + llvm::raw_ostream &out, + unsigned indent) const { + PrintOptions PO; + PO.PrintTypesForDebugging = true; + + out << "Potential bindings for "; + typeVar->getImpl().print(out); + out << "\n"; + + out << "[constraints: "; + interleave(Constraints, + [&](Constraint *constraint) { + constraint->print(out, &cs.getASTContext().SourceMgr, indent, + /*skipLocator=*/true); + }, + [&out]() { out << ", "; }); + out << "] "; + + if (!AdjacentVars.empty()) { + out << "[adjacent to: "; + SmallVector> adjacentVars( + AdjacentVars.begin(), AdjacentVars.end()); + llvm::sort(adjacentVars, + [](auto lhs, auto rhs) { + return lhs.first->getID() < rhs.first->getID(); + }); + interleave(adjacentVars, + [&](std::pair pair) { + out << pair.first->getString(PO); + if (pair.first->getImpl().getFixedType(/*record=*/nullptr)) + out << " (fixed)"; + out << " via "; + pair.second->print(out, &cs.getASTContext().SourceMgr, indent, + /*skipLocator=*/true); + }, + [&out]() { out << ", "; }); + out << "] "; + } +} + void BindingSet::forEachLiteralRequirement( llvm::function_ref callback) const { for (const auto &literal : Literals) { @@ -2181,22 +2313,21 @@ void BindingSet::dump(llvm::raw_ostream &out, unsigned indent) const { if (!attributes.empty()) out << "] "; - if (involvesTypeVariables()) { + if (!AdjacentVars.empty()) { out << "[adjacent to: "; - if (AdjacentVars.empty()) { - out << ""; - } else { - SmallVector adjacentVars(AdjacentVars.begin(), - AdjacentVars.end()); - llvm::sort(adjacentVars, - [](const TypeVariableType *lhs, const TypeVariableType *rhs) { + SmallVector adjacentVars(AdjacentVars.begin(), + AdjacentVars.end()); + llvm::sort(adjacentVars, + [](const TypeVariableType *lhs, const TypeVariableType *rhs) { return lhs->getID() < rhs->getID(); - }); - interleave( - adjacentVars, - [&](const auto *typeVar) { out << typeVar->getString(PO); }, - [&out]() { out << ", "; }); - } + }); + interleave(adjacentVars, + [&](auto *typeVar) { + out << typeVar->getString(PO); + if (typeVar->getImpl().getFixedType(/*record=*/nullptr)) + out << " (fixed)"; + }, + [&out]() { out << ", "; }); out << "] "; } @@ -2209,24 +2340,25 @@ void BindingSet::dump(llvm::raw_ostream &out, unsigned indent) const { enum class BindingKind { Exact, Subtypes, Supertypes, Literal }; BindingKind Kind; Type BindingType; - PrintableBinding(BindingKind kind, Type bindingType) - : Kind(kind), BindingType(bindingType) {} + bool Viable; + PrintableBinding(BindingKind kind, Type bindingType, bool viable) + : Kind(kind), BindingType(bindingType), Viable(viable) {} public: static PrintableBinding supertypesOf(Type binding) { - return PrintableBinding{BindingKind::Supertypes, binding}; + return PrintableBinding{BindingKind::Supertypes, binding, true}; } static PrintableBinding subtypesOf(Type binding) { - return PrintableBinding{BindingKind::Subtypes, binding}; + return PrintableBinding{BindingKind::Subtypes, binding, true}; } static PrintableBinding exact(Type binding) { - return PrintableBinding{BindingKind::Exact, binding}; + return PrintableBinding{BindingKind::Exact, binding, true}; } - static PrintableBinding literalDefaultType(Type binding) { - return PrintableBinding{BindingKind::Literal, binding}; + static PrintableBinding literalDefaultType(Type binding, bool viable) { + return PrintableBinding{BindingKind::Literal, binding, viable}; } void print(llvm::raw_ostream &out, const PrintOptions &PO, @@ -2244,7 +2376,10 @@ void BindingSet::dump(llvm::raw_ostream &out, unsigned indent) const { out << "(default type of literal) "; break; } - BindingType.print(out, PO); + if (BindingType) + BindingType.print(out, PO); + if (!Viable) + out << " [literal not viable]"; } }; @@ -2266,10 +2401,11 @@ void BindingSet::dump(llvm::raw_ostream &out, unsigned indent) const { } } for (const auto &literal : Literals) { - if (literal.second.viableAsBinding()) { - potentialBindings.push_back(PrintableBinding::literalDefaultType( - literal.second.getDefaultType())); - } + potentialBindings.push_back(PrintableBinding::literalDefaultType( + literal.second.hasDefaultType() + ? literal.second.getDefaultType() + : Type(), + literal.second.viableAsBinding())); } if (potentialBindings.empty()) { out << ""; diff --git a/lib/Sema/CSOptimizer.cpp b/lib/Sema/CSOptimizer.cpp index f7ce1d1837d45..2f0aa13fb082a 100644 --- a/lib/Sema/CSOptimizer.cpp +++ b/lib/Sema/CSOptimizer.cpp @@ -382,7 +382,7 @@ static void determineBestChoicesInContext( SmallVector, 2> types; if (auto *typeVar = argType->getAs()) { - auto bindingSet = cs.getBindingsFor(typeVar, /*finalize=*/true); + auto bindingSet = cs.getBindingsFor(typeVar); for (const auto &binding : bindingSet.Bindings) { types.push_back({binding.BindingType, /*fromLiteral=*/false}); @@ -421,7 +421,7 @@ static void determineBestChoicesInContext( auto resultType = cs.simplifyType(argFuncType->getResult()); if (auto *typeVar = resultType->getAs()) { - auto bindingSet = cs.getBindingsFor(typeVar, /*finalize=*/true); + auto bindingSet = cs.getBindingsFor(typeVar); for (const auto &binding : bindingSet.Bindings) { resultTypes.push_back(binding.BindingType); diff --git a/lib/Sema/CSStep.cpp b/lib/Sema/CSStep.cpp index 1159a886f428d..1a63ee5cc37f5 100644 --- a/lib/Sema/CSStep.cpp +++ b/lib/Sema/CSStep.cpp @@ -95,6 +95,12 @@ void SplitterStep::computeFollowupSteps( // Contract the edges of the constraint graph. CG.optimize(); + if (CS.getASTContext().TypeCheckerOpts.SolverDisableSplitter) { + steps.push_back(std::make_unique( + CS, 0, &CS.InactiveConstraints, Solutions)); + return; + } + // Compute the connected components of the constraint graph. auto components = CG.computeConnectedComponents(CS.getTypeVariables()); unsigned numComponents = components.size(); diff --git a/lib/Sema/ConstraintGraph.cpp b/lib/Sema/ConstraintGraph.cpp index 4e2270e2a5359..3ccb5173829f6 100644 --- a/lib/Sema/ConstraintGraph.cpp +++ b/lib/Sema/ConstraintGraph.cpp @@ -97,7 +97,8 @@ void ConstraintGraphNode::reset() { TypeVar = nullptr; EquivalenceClass.clear(); - Bindings.reset(); + Potential.reset(); + Set.reset(); } bool ConstraintGraphNode::forRepresentativeVar() const { @@ -229,8 +230,10 @@ void ConstraintGraphNode::notifyReferencingVars( void ConstraintGraphNode::notifyReferencedVars( llvm::function_ref notification) const { - for (auto *fixedBinding : getReferencedVars()) { - notification(CG[fixedBinding]); + for (auto *referencedVar : getReferencedVars()) { + auto *repr = referencedVar->getImpl().getRepresentative(/*record=*/nullptr); + if (!repr->getImpl().getFixedType(/*record=*/nullptr)) + notification(CG[repr]); } } @@ -284,30 +287,6 @@ void ConstraintGraphNode::removeReferencedBy(TypeVariableType *typeVar) { } } -void ConstraintGraphNode::introduceToInference(Constraint *constraint) { - if (forRepresentativeVar()) { - auto fixedType = TypeVar->getImpl().getFixedType(/*record=*/nullptr); - if (!fixedType) - getCurrentBindings().infer(CG.getConstraintSystem(), TypeVar, constraint); - } else { - auto *repr = - getTypeVariable()->getImpl().getRepresentative(/*record=*/nullptr); - CG[repr].introduceToInference(constraint); - } -} - -void ConstraintGraphNode::retractFromInference(Constraint *constraint) { - if (forRepresentativeVar()) { - auto fixedType = TypeVar->getImpl().getFixedType(/*record=*/nullptr); - if (!fixedType) - getCurrentBindings().retract(CG.getConstraintSystem(), TypeVar,constraint); - } else { - auto *repr = - getTypeVariable()->getImpl().getRepresentative(/*record=*/nullptr); - CG[repr].retractFromInference(constraint); - } -} - void ConstraintGraphNode::updateFixedType( Type fixedType, llvm::function_refgetTypeVariables(referencedVars); for (auto *referencedVar : referencedVars) { - auto &node = CG[referencedVar]; + auto *repr = referencedVar->getImpl().getRepresentative(/*record=*/nullptr); + if (repr->getImpl().getFixedType(/*record=*/nullptr)) + continue; + + auto &node = CG[repr]; // Newly referred vars need to re-introduce all constraints associated // with this type variable since they are now going to be used in @@ -340,18 +323,20 @@ void ConstraintGraphNode::updateFixedType( } void ConstraintGraphNode::retractFromInference(Type fixedType) { + auto &cs = CG.getConstraintSystem(); return updateFixedType( fixedType, - [](ConstraintGraphNode &node, Constraint *constraint) { - node.retractFromInference(constraint); + [&cs](ConstraintGraphNode &node, Constraint *constraint) { + node.getPotentialBindings().retract(cs, node.getTypeVariable(), constraint); }); } void ConstraintGraphNode::introduceToInference(Type fixedType) { + auto &cs = CG.getConstraintSystem(); return updateFixedType( fixedType, - [](ConstraintGraphNode &node, Constraint *constraint) { - node.introduceToInference(constraint); + [&cs](ConstraintGraphNode &node, Constraint *constraint) { + node.getPotentialBindings().infer(cs, node.getTypeVariable(), constraint); }); } @@ -376,13 +361,13 @@ void ConstraintGraph::addConstraint(Constraint *constraint) { addConstraint(typeVar, constraint); - auto &node = (*this)[typeVar]; - - node.introduceToInference(constraint); + auto *repr = typeVar->getImpl().getRepresentative(/*record=*/nullptr); + if (!repr->getImpl().getFixedType(/*record=*/nullptr)) + (*this)[repr].getPotentialBindings().infer(CS, repr, constraint); if (isUsefulForReferencedVars(constraint)) { - node.notifyReferencedVars([&](ConstraintGraphNode &referencedVar) { - referencedVar.introduceToInference(constraint); + (*this)[typeVar].notifyReferencedVars([&](ConstraintGraphNode &node) { + node.getPotentialBindings().infer(CS, node.getTypeVariable(), constraint); }); } } @@ -414,14 +399,13 @@ void ConstraintGraph::removeConstraint(Constraint *constraint) { // For the nodes corresponding to each type variable... auto referencedTypeVars = constraint->getTypeVariables(); for (auto typeVar : referencedTypeVars) { - // Find the node for this type variable. - auto &node = (*this)[typeVar]; - - node.retractFromInference(constraint); + auto *repr = typeVar->getImpl().getRepresentative(/*record=*/nullptr); + if (!repr->getImpl().getFixedType(/*record=*/nullptr)) + (*this)[repr].getPotentialBindings().retract(CS, repr, constraint); if (isUsefulForReferencedVars(constraint)) { - node.notifyReferencedVars([&](ConstraintGraphNode &referencedVar) { - referencedVar.retractFromInference(constraint); + (*this)[typeVar].notifyReferencedVars([&](ConstraintGraphNode &node) { + node.getPotentialBindings().retract(CS, node.getTypeVariable(), constraint); }); } @@ -467,7 +451,7 @@ void ConstraintGraph::mergeNodesPre(TypeVariableType *typeVar2) { node.notifyReferencingVars( [&](ConstraintGraphNode &node, Constraint *constraint) { - node.retractFromInference(constraint); + node.getPotentialBindings().retract(CS, node.getTypeVariable(), constraint); }); } } @@ -497,19 +481,20 @@ void ConstraintGraph::mergeNodes(TypeVariableType *typeVar1, auto &node = (*this)[newMember]; for (auto *constraint : node.getConstraints()) { - repNode.introduceToInference(constraint); + if (!typeVar1->getImpl().getFixedType(/*record=*/nullptr)) + repNode.getPotentialBindings().infer(CS, typeVar1, constraint); if (!isUsefulForReferencedVars(constraint)) continue; - repNode.notifyReferencedVars([&](ConstraintGraphNode &referencedVar) { - referencedVar.introduceToInference(constraint); + repNode.notifyReferencedVars([&](ConstraintGraphNode &node) { + node.getPotentialBindings().infer(CS, node.getTypeVariable(), constraint); }); } node.notifyReferencingVars( [&](ConstraintGraphNode &node, Constraint *constraint) { - node.introduceToInference(constraint); + node.getPotentialBindings().infer(CS, node.getTypeVariable(), constraint); }); } } @@ -557,12 +542,12 @@ void ConstraintGraph::unrelateTypeVariables(TypeVariableType *typeVar, void ConstraintGraph::inferBindings(TypeVariableType *typeVar, Constraint *constraint) { - (*this)[typeVar].getCurrentBindings().infer(CS, typeVar, constraint); + (*this)[typeVar].getPotentialBindings().infer(CS, typeVar, constraint); } void ConstraintGraph::retractBindings(TypeVariableType *typeVar, Constraint *constraint) { - (*this)[typeVar].getCurrentBindings().retract(CS, typeVar, constraint); + (*this)[typeVar].getPotentialBindings().retract(CS, typeVar, constraint); } #pragma mark Algorithms diff --git a/unittests/Sema/BindingInferenceTests.cpp b/unittests/Sema/BindingInferenceTests.cpp index 9ff2046aa8e26..34f29ad2730d7 100644 --- a/unittests/Sema/BindingInferenceTests.cpp +++ b/unittests/Sema/BindingInferenceTests.cpp @@ -118,17 +118,17 @@ TEST_F(SemaTest, TestIntLiteralBindingInference) { cs.getConstraintLocator({})); { - auto bindings = cs.getBindingsFor(otherTy); + cs.getConstraintGraph()[otherTy].initBindingSet(); + auto &bindings = cs.getConstraintGraph()[otherTy].getBindingSet(); // Make sure that there are no direct bindings or protocol requirements. ASSERT_EQ(bindings.Bindings.size(), (unsigned)0); ASSERT_EQ(bindings.Literals.size(), (unsigned)0); - llvm::SmallDenseMap env; - env.insert({floatLiteralTy, cs.getBindingsFor(floatLiteralTy)}); + cs.getConstraintGraph()[floatLiteralTy].initBindingSet(); - bindings.finalize(env); + bindings.finalize(/*transitive=*/true); // Inferred a single transitive binding through `$T_float`. ASSERT_EQ(bindings.Bindings.size(), (unsigned)1); diff --git a/unittests/Sema/SemaFixture.cpp b/unittests/Sema/SemaFixture.cpp index 98884f24d361b..5ee9c005a5e93 100644 --- a/unittests/Sema/SemaFixture.cpp +++ b/unittests/Sema/SemaFixture.cpp @@ -126,24 +126,25 @@ ProtocolType *SemaTest::createProtocol(llvm::StringRef protocolName, BindingSet SemaTest::inferBindings(ConstraintSystem &cs, TypeVariableType *typeVar) { - llvm::SmallDenseMap cache; - for (auto *typeVar : cs.getTypeVariables()) { + auto &node = cs.getConstraintGraph()[typeVar]; + node.resetBindingSet(); + if (!typeVar->getImpl().hasRepresentativeOrFixed()) - cache.insert({typeVar, cs.getBindingsFor(typeVar, /*finalize=*/false)}); + node.initBindingSet(); } for (auto *typeVar : cs.getTypeVariables()) { - auto cachedBindings = cache.find(typeVar); - if (cachedBindings == cache.end()) + auto &node = cs.getConstraintGraph()[typeVar]; + if (!node.hasBindingSet()) continue; - auto &bindings = cachedBindings->getSecond(); - bindings.inferTransitiveProtocolRequirements(cache); - bindings.finalize(cache); + auto &bindings = node.getBindingSet(); + bindings.inferTransitiveProtocolRequirements(); + bindings.finalize(/*transitive=*/true); } - auto result = cache.find(typeVar); - assert(result != cache.end()); - return result->second; + auto &node = cs.getConstraintGraph()[typeVar]; + ASSERT(node.hasBindingSet()); + return node.getBindingSet(); } diff --git a/validation-test/Sema/type_checker_perf/fast/array_concatenation.swift b/validation-test/Sema/type_checker_perf/fast/array_concatenation.swift index 1ce9d4293dc49..873889b4a20bd 100644 --- a/validation-test/Sema/type_checker_perf/fast/array_concatenation.swift +++ b/validation-test/Sema/type_checker_perf/fast/array_concatenation.swift @@ -1,4 +1,4 @@ -// RUN: %target-typecheck-verify-swift -solver-disable-shrink +// RUN: %target-typecheck-verify-swift // Self-contained test case protocol P1 {}; func f(_: T, _: T) -> T { fatalError() } diff --git a/validation-test/Sema/type_checker_perf/fast/property_vs_unapplied_func.swift b/validation-test/Sema/type_checker_perf/fast/property_vs_unapplied_func.swift index 9cd53902f42ad..0f44ea3c31032 100644 --- a/validation-test/Sema/type_checker_perf/fast/property_vs_unapplied_func.swift +++ b/validation-test/Sema/type_checker_perf/fast/property_vs_unapplied_func.swift @@ -1,4 +1,4 @@ -// RUN: %target-typecheck-verify-swift -solver-expression-time-threshold=1 -solver-disable-shrink +// RUN: %target-typecheck-verify-swift -solver-expression-time-threshold=1 // REQUIRES: tools-release,no_asan struct Date { diff --git a/validation-test/Sema/type_checker_perf/slow/rdar26564101.swift b/validation-test/Sema/type_checker_perf/slow/rdar26564101.swift index 09ca79329d7a3..e9e1d32c90f05 100644 --- a/validation-test/Sema/type_checker_perf/slow/rdar26564101.swift +++ b/validation-test/Sema/type_checker_perf/slow/rdar26564101.swift @@ -1,4 +1,4 @@ -// RUN: %target-typecheck-verify-swift -solver-expression-time-threshold=1 -solver-disable-shrink +// RUN: %target-typecheck-verify-swift -solver-expression-time-threshold=1 // REQUIRES: tools-release,no_asan // UNSUPPORTED: swift_test_mode_optimize_none && OS=linux-gnu