diff --git a/include/swift/Sema/CSTrail.def b/include/swift/Sema/CSTrail.def new file mode 100644 index 0000000000000..cc9d82ac8a918 --- /dev/null +++ b/include/swift/Sema/CSTrail.def @@ -0,0 +1,88 @@ +//===--- CSTrail.def - Trail Change Kinds ---------------------------------===// +// +// This source file is part of the Swift.org open source project +// +// Copyright (c) 2024 Apple Inc. and the Swift project authors +// Licensed under Apache License v2.0 with Runtime Library Exception +// +// See https://swift.org/LICENSE.txt for license information +// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors +// +//===----------------------------------------------------------------------===// +/// +/// This file enumerates the kinds of SolverTrail::Change. +/// +//===----------------------------------------------------------------------===// + +#ifndef CHANGE +#define CHANGE(Name) +#endif + +#ifndef LOCATOR_CHANGE +#define LOCATOR_CHANGE(Name, Map) CHANGE(Name) +#endif + +#ifndef EXPR_CHANGE +#define EXPR_CHANGE(Name) CHANGE(Name) +#endif + +#ifndef CLOSURE_CHANGE +#define CLOSURE_CHANGE(Name) CHANGE(Name) +#endif + +#ifndef LAST_CHANGE +#define LAST_CHANGE(Name) +#endif + +LOCATOR_CHANGE(RecordedDisjunctionChoice, DisjunctionChoices) +LOCATOR_CHANGE(RecordedAppliedDisjunction, AppliedDisjunctions) +LOCATOR_CHANGE(RecordedMatchCallArgumentResult, argumentMatchingChoices) +LOCATOR_CHANGE(RecordedOpenedTypes, OpenedTypes) +LOCATOR_CHANGE(RecordedOpenedExistentialType, OpenedExistentialTypes) +LOCATOR_CHANGE(RecordedPackExpansionEnvironment, PackExpansionEnvironments) +LOCATOR_CHANGE(RecordedDefaultedConstraint, DefaultedConstraints) +LOCATOR_CHANGE(ResolvedOverload, ResolvedOverloads) +LOCATOR_CHANGE(RecordedImplicitValueConversion, ImplicitValueConversions) +LOCATOR_CHANGE(RecordedArgumentList, ArgumentLists) +LOCATOR_CHANGE(RecordedImplicitCallAsFunctionRoot, ImplicitCallAsFunctionRoots) +LOCATOR_CHANGE(RecordedSynthesizedConformance, SynthesizedConformances) + +EXPR_CHANGE(AppliedPropertyWrapper) +EXPR_CHANGE(RecordedImpliedResult) +EXPR_CHANGE(RecordedExprPattern) + +CLOSURE_CHANGE(RecordedClosureType) +CLOSURE_CHANGE(RecordedPreconcurrencyClosure) + +CHANGE(AddedTypeVariable) +CHANGE(AddedConstraint) +CHANGE(RemovedConstraint) +CHANGE(ExtendedEquivalenceClass) +CHANGE(RelatedTypeVariables) +CHANGE(InferredBindings) +CHANGE(RetractedBindings) +CHANGE(UpdatedTypeVariable) +CHANGE(AddedConversionRestriction) +CHANGE(AddedFix) +CHANGE(AddedFixedRequirement) +CHANGE(RecordedOpenedPackExpansionType) +CHANGE(RecordedPackEnvironment) +CHANGE(RecordedNodeType) +CHANGE(RecordedKeyPathComponentType) +CHANGE(DisabledConstraint) +CHANGE(FavoredConstraint) +CHANGE(RecordedResultBuilderTransform) +CHANGE(RecordedContextualInfo) +CHANGE(RecordedTarget) +CHANGE(RecordedCaseLabelItemInfo) +CHANGE(RecordedPotentialThrowSite) +CHANGE(RecordedIsolatedParam) +CHANGE(RecordedKeyPath) + +LAST_CHANGE(RecordedKeyPath) + +#undef LOCATOR_CHANGE +#undef EXPR_CHANGE +#undef CLOSURE_CHANGE +#undef LAST_CHANGE +#undef CHANGE \ No newline at end of file diff --git a/include/swift/Sema/CSTrail.h b/include/swift/Sema/CSTrail.h index 93ab0c93f3d7c..62488de609685 100644 --- a/include/swift/Sema/CSTrail.h +++ b/include/swift/Sema/CSTrail.h @@ -17,6 +17,9 @@ #ifndef SWIFT_SEMA_CSTRAIL_H #define SWIFT_SEMA_CSTRAIL_H +#include "swift/AST/AnyFunctionRef.h" +#include "swift/AST/Type.h" +#include "swift/AST/Types.h" #include namespace llvm { @@ -31,53 +34,18 @@ class TypeVariableType; namespace constraints { class Constraint; +struct SyntacticElementTargetKey; class SolverTrail { public: /// The kind of change made to the graph. enum class ChangeKind: unsigned { - /// Added a new vertex to the constraint graph. - AddedTypeVariable, - /// Added a new constraint to the constraint graph. - AddedConstraint, - /// Removed an existing constraint from the constraint graph. - RemovedConstraint, - /// Extended the equivalence class of a type variable in the constraint graph. - ExtendedEquivalenceClass, - /// Added a new edge in the constraint graph. - RelatedTypeVariables, - /// Inferred potential bindings from a constraint. - InferredBindings, - /// Retracted potential bindings from a constraint. - RetractedBindings, - /// Set the fixed type or parent and flags for a type variable. - UpdatedTypeVariable, - /// Recorded a conversion restriction kind. - AddedConversionRestriction, - /// Recorded a fix. - AddedFix, - /// Recorded a fixed requirement. - AddedFixedRequirement, - /// Recorded a disjunction choice. - RecordedDisjunctionChoice, - /// Recorded an applied disjunction. - RecordedAppliedDisjunction, - /// Recorded an argument matching choice. - RecordedMatchCallArgumentResult, - /// Recorded a list of opened types at a locator. - RecordedOpenedTypes, - /// Recorded the opening of an existential type at a locator. - RecordedOpenedExistentialType, - /// Recorded the opening of a pack existential type. - RecordedOpenedPackExpansionType, - /// Recorded the creation of a generic environment for a pack expansion expression. - RecordedPackExpansionEnvironment, - /// Recorded the mapping from a pack element expression to its parent - /// pack expansion expression. - RecordedPackEnvironment, - /// Record a defaulted constraint at a locator. - RecordedDefaultedConstraint, +#define CHANGE(Name) Name, +#define LAST_CHANGE(Name) Last = Name +#include "CSTrail.def" +#undef CHANGE +#undef LAST_CHANGE }; /// A change made to the constraint system. @@ -135,91 +103,131 @@ class SolverTrail { Type DstType; } Restriction; - ConstraintFix *Fix; struct { GenericTypeParamType *GP; Type ReqTy; } FixedRequirement; - ConstraintLocator *Locator; - PackExpansionType *ExpansionTy; - PackElementExpr *ElementExpr; + struct { + ASTNode Node; + Type OldType; + } Node; + + struct { + const KeyPathExpr *Expr; + Type OldType; + } KeyPath; + + ConstraintFix *TheFix; + ConstraintLocator *TheLocator; + PackExpansionType *TheExpansion; + PackElementExpr *TheElement; + Expr *TheExpr; + Stmt *TheStmt; + StmtConditionElement *TheCondElt; + Pattern *ThePattern; + PatternBindingDecl *ThePatternBinding; + VarDecl *TheVar; + AnyFunctionRef TheRef; + ClosureExpr *TheClosure; + DeclContext *TheDeclContext; + CaseLabelItem *TheItem; + CatchNode TheCatchNode; + ParamDecl *TheParam; }; Change() : Kind(ChangeKind::AddedTypeVariable), TypeVar(nullptr) { } +#define LOCATOR_CHANGE(Name, _) static Change Name(ConstraintLocator *locator); +#define EXPR_CHANGE(Name) static Change Name(Expr *expr); +#define CLOSURE_CHANGE(Name) static Change Name(ClosureExpr *closure); +#include "swift/Sema/CSTrail.def" + /// Create a change that added a type variable. - static Change addedTypeVariable(TypeVariableType *typeVar); + static Change AddedTypeVariable(TypeVariableType *typeVar); /// Create a change that added a constraint. - static Change addedConstraint(TypeVariableType *typeVar, Constraint *constraint); + static Change AddedConstraint(TypeVariableType *typeVar, Constraint *constraint); /// Create a change that removed a constraint. - static Change removedConstraint(TypeVariableType *typeVar, Constraint *constraint); + static Change RemovedConstraint(TypeVariableType *typeVar, Constraint *constraint); /// Create a change that extended an equivalence class. - static Change extendedEquivalenceClass(TypeVariableType *typeVar, + static Change ExtendedEquivalenceClass(TypeVariableType *typeVar, unsigned prevSize); /// Create a change that updated the references/referenced by sets of /// a type variable pair. - static Change relatedTypeVariables(TypeVariableType *typeVar, + static Change RelatedTypeVariables(TypeVariableType *typeVar, TypeVariableType *otherTypeVar); /// Create a change that inferred bindings from a constraint. - static Change inferredBindings(TypeVariableType *typeVar, + static Change InferredBindings(TypeVariableType *typeVar, Constraint *constraint); /// Create a change that retracted bindings from a constraint. - static Change retractedBindings(TypeVariableType *typeVar, + static Change RetractedBindings(TypeVariableType *typeVar, Constraint *constraint); /// Create a change that updated a type variable. - static Change updatedTypeVariable( + static Change UpdatedTypeVariable( TypeVariableType *typeVar, llvm::PointerUnion parentOrFixed, unsigned options); /// Create a change that recorded a restriction. - static Change addedConversionRestriction(Type srcType, Type dstType); + static Change AddedConversionRestriction(Type srcType, Type dstType); /// Create a change that recorded a fix. - static Change addedFix(ConstraintFix *fix); + static Change AddedFix(ConstraintFix *fix); /// Create a change that recorded a fixed requirement. - static Change addedFixedRequirement(GenericTypeParamType *GP, + static Change AddedFixedRequirement(GenericTypeParamType *GP, unsigned reqKind, Type requirementTy); - /// Create a change that recorded a disjunction choice. - static Change recordedDisjunctionChoice(ConstraintLocator *locator, - unsigned index); + /// Create a change that recorded the opening of a pack expansion type. + static Change RecordedOpenedPackExpansionType(PackExpansionType *expansion); - /// Create a change that recorded an applied disjunction. - static Change recordedAppliedDisjunction(ConstraintLocator *locator); + /// Create a change that recorded a mapping from a pack element expression + /// to its parent expansion expression. + static Change RecordedPackEnvironment(PackElementExpr *packElement); - /// Create a change that recorded an applied disjunction. - static Change recordedMatchCallArgumentResult(ConstraintLocator *locator); + /// Create a change that recorded an assignment of a type to an AST node. + static Change RecordedNodeType(ASTNode node, Type oldType); - /// Create a change that recorded a list of opened types. - static Change recordedOpenedTypes(ConstraintLocator *locator); + /// Create a change that recorded an assignment of a type to an AST node. + static Change RecordedKeyPathComponentType(const KeyPathExpr *expr, + unsigned component, + Type oldType); - /// Create a change that recorded the opening of an existential type. - static Change recordedOpenedExistentialType(ConstraintLocator *locator); + /// Create a change that disabled a constraint. + static Change DisabledConstraint(Constraint *constraint); - /// Create a change that recorded the opening of a pack expansion type. - static Change recordedOpenedPackExpansionType(PackExpansionType *expansion); + /// Create a change that favored a constraint. + static Change FavoredConstraint(Constraint *constraint); - /// Create a change that recorded the opening of a pack expansion type. - static Change recordedPackExpansionEnvironment(ConstraintLocator *locator); + /// Create a change that recorded a result builder transform. + static Change RecordedResultBuilderTransform(AnyFunctionRef fn); - /// Create a change that recorded a mapping from a pack element expression - /// to its parent expansion expression. - static Change recordedPackEnvironment(PackElementExpr *packElement); + /// Create a change that recorded the contextual type of an AST node. + static Change RecordedContextualInfo(ASTNode node); + + /// Create a change that recorded a SyntacticElementTarget. + static Change RecordedTarget(SyntacticElementTargetKey key); + + /// Create a change that recorded a SyntacticElementTarget. + static Change RecordedCaseLabelItemInfo(CaseLabelItem *item); - /// Create a change that recorded a defaulted constraint at a locator. - static Change recordedDefaultedConstraint(ConstraintLocator *locator); + /// Create a change that recorded a potential throw site. + static Change RecordedPotentialThrowSite(CatchNode catchNode); + + /// Create a change that recorded an isolated parameter. + static Change RecordedIsolatedParam(ParamDecl *param); + + /// Create a change that recorded a key path expression. + static Change RecordedKeyPath(KeyPathExpr *expr); /// Undo this change, reverting the constraint graph to the state it /// had prior to this change. @@ -229,9 +237,12 @@ class SolverTrail { void dump(llvm::raw_ostream &out, ConstraintSystem &cs, unsigned indent = 0) const; + + private: + SyntacticElementTargetKey getSyntacticElementTargetKey() const; }; - SolverTrail(ConstraintSystem &cs) : CS(cs) {} + SolverTrail(ConstraintSystem &cs); ~SolverTrail(); @@ -258,6 +269,8 @@ class SolverTrail { /// The list of changes made to this constraint system. std::vector Changes; + uint64_t Profile[unsigned(ChangeKind::Last) + 1]; + bool UndoActive = false; unsigned Total = 0; unsigned Max = 0; diff --git a/include/swift/Sema/ConstraintSystem.h b/include/swift/Sema/ConstraintSystem.h index 0dc75fbf43742..ad8f6a24b006f 100644 --- a/include/swift/Sema/ConstraintSystem.h +++ b/include/swift/Sema/ConstraintSystem.h @@ -435,7 +435,7 @@ class TypeVariableType::Implementation { /// Record the current type-variable binding. void recordBinding(constraints::SolverTrail &trail) { - trail.recordChange(constraints::SolverTrail::Change::updatedTypeVariable( + trail.recordChange(constraints::SolverTrail::Change::UpdatedTypeVariable( getTypeVariable(), ParentOrFixed, getRawOptions())); } @@ -1203,8 +1203,7 @@ struct CaseLabelItemInfo { /// Key to the constraint solver's mapping from AST nodes to their corresponding /// target. -class SyntacticElementTargetKey { -public: +struct SyntacticElementTargetKey { enum class Kind { empty, tombstone, @@ -1218,72 +1217,75 @@ class SyntacticElementTargetKey { functionRef, }; -private: Kind kind; union { - const StmtConditionElement *stmtCondElement; + StmtConditionElement *stmtCondElement; - const Expr *expr; + Expr *expr; - const Stmt *stmt; + Stmt *stmt; - const Pattern *pattern; + Pattern *pattern; struct PatternBindingEntry { - const PatternBindingDecl *patternBinding; + PatternBindingDecl *patternBinding; unsigned index; } patternBindingEntry; - const VarDecl *varDecl; + VarDecl *varDecl; - const DeclContext *functionRef; + DeclContext *functionRef; } storage; -public: SyntacticElementTargetKey(Kind kind) { assert(kind == Kind::empty || kind == Kind::tombstone); this->kind = kind; } - SyntacticElementTargetKey(const StmtConditionElement *stmtCondElement) { + SyntacticElementTargetKey(StmtConditionElement *stmtCondElement) { kind = Kind::stmtCondElement; storage.stmtCondElement = stmtCondElement; } - SyntacticElementTargetKey(const Expr *expr) { + SyntacticElementTargetKey(Expr *expr) { kind = Kind::expr; storage.expr = expr; } - SyntacticElementTargetKey(const ClosureExpr *closure) { + SyntacticElementTargetKey(ClosureExpr *closure) { kind = Kind::closure; storage.expr = closure; } - SyntacticElementTargetKey(const Stmt *stmt) { + SyntacticElementTargetKey(Stmt *stmt) { kind = Kind::stmt; storage.stmt = stmt; } - SyntacticElementTargetKey(const Pattern *pattern) { + SyntacticElementTargetKey(Pattern *pattern) { kind = Kind::pattern; storage.pattern = pattern; } - SyntacticElementTargetKey(const PatternBindingDecl *patternBinding, + SyntacticElementTargetKey(PatternBindingDecl *patternBinding, unsigned index) { kind = Kind::patternBindingEntry; storage.patternBindingEntry.patternBinding = patternBinding; storage.patternBindingEntry.index = index; } - SyntacticElementTargetKey(const VarDecl *varDecl) { + SyntacticElementTargetKey(VarDecl *varDecl) { kind = Kind::varDecl; storage.varDecl = varDecl; } - SyntacticElementTargetKey(const AnyFunctionRef functionRef) { + SyntacticElementTargetKey(DeclContext *dc) { + kind = Kind::functionRef; + storage.functionRef = dc; + } + + SyntacticElementTargetKey(AnyFunctionRef functionRef) { kind = Kind::functionRef; storage.functionRef = functionRef.getAsDeclContext(); } @@ -1515,7 +1517,7 @@ class Solution { /// Maps expressions for implied results (e.g implicit 'then' statements, /// implicit 'return' statements in single expression body closures) to their /// result kind. - llvm::MapVector ImpliedResults; + llvm::DenseMap ImpliedResults; /// For locators associated with call expressions, the trailing closure /// matching rule and parameter bindings that were applied. @@ -1565,51 +1567,50 @@ class Solution { /// The key path expression and its root type, value type, and decl context /// introduced by this solution. - llvm::MapVector> + llvm::DenseMap> KeyPaths; /// Contextual types introduced by this solution. std::vector> contextualTypes; /// Maps AST nodes to their target. - llvm::MapVector targets; + llvm::DenseMap targets; /// Maps case label items to information tracked about them as they are /// being solved. - llvm::MapVector - caseLabelItems; + llvm::DenseMap caseLabelItems; /// Maps catch nodes to the set of potential throw sites that will be caught /// at that location. - /// The set of opened types for a given locator. + /// Keep track of all of the potential throw sites. std::vector> potentialThrowSites; /// A map of expressions to the ExprPatterns that they are being solved as /// a part of. - llvm::MapVector exprPatterns; + llvm::DenseMap exprPatterns; /// The set of parameters that have been inferred to be 'isolated'. - std::vector isolatedParams; + llvm::DenseSet isolatedParams; /// The set of closures that have been inferred to be "isolated by /// preconcurrency". - std::vector preconcurrencyClosures; + llvm::DenseSet preconcurrencyClosures; /// The set of functions that have been transformed by a result builder. llvm::MapVector resultBuilderTransformed; /// A map from argument expressions to their applied property wrapper expressions. - llvm::MapVector> appliedPropertyWrappers; + llvm::DenseMap> appliedPropertyWrappers; /// A mapping from the constraint locators for references to various /// names (e.g., member references, normal name references, possible /// constructions) to the argument lists for the call to that locator. - llvm::MapVector argumentLists; + llvm::DenseMap argumentLists; /// The set of implicitly generated `.callAsFunction` root expressions. llvm::DenseMap @@ -1617,7 +1618,7 @@ class Solution { /// The set of conformances synthesized during solving (i.e. for /// ad-hoc distributed `SerializationRequirement` conformances). - llvm::MapVector + llvm::DenseMap SynthesizedConformances; /// Record a new argument matching choice for given locator that maps a @@ -1727,12 +1728,7 @@ class Solution { } std::optional - getTargetFor(SyntacticElementTargetKey key) const { - auto known = targets.find(key); - if (known == targets.end()) - return std::nullopt; - return known->second; - } + getTargetFor(SyntacticElementTargetKey key) const; ConstraintLocator *getCalleeLocator(ConstraintLocator *locator, bool lookThroughApply = true) const; @@ -2207,7 +2203,7 @@ class ConstraintSystem { llvm::FoldingSetVector ConstraintLocators; /// The overload sets that have been resolved along the current path. - llvm::MapVector ResolvedOverloads; + llvm::DenseMap ResolvedOverloads; /// The current fixed score for this constraint system and the (partial) /// solution it represents. @@ -2221,6 +2217,9 @@ class ConstraintSystem { /// Maps discovered closures to their types inferred /// from declared parameters/result and body. + /// + /// This is a MapVector because contractEdges() iterates over it and + /// may depend on order. llvm::MapVector ClosureTypes; /// Maps closures and local functions to the pack expansion expressions they @@ -2230,7 +2229,7 @@ class ConstraintSystem { /// Maps expressions for implied results (e.g implicit 'then' statements, /// implicit 'return' statements in single expression body closures) to their /// result kind. - llvm::MapVector ImpliedResults; + llvm::DenseMap ImpliedResults; /// This is a *global* list of all result builder bodies that have /// been determined to be incorrect by failing constraint generation. @@ -2256,11 +2255,7 @@ class ConstraintSystem { /// nodes themselves. This allows us to typecheck and /// run through various diagnostics passes without actually mutating /// the types on the nodes. - llvm::MapVector NodeTypes; - - /// The nodes for which we have produced types, along with the prior type - /// each node had before introducing this type. - llvm::SmallVector, 8> addedNodeTypes; + llvm::DenseMap NodeTypes; /// Maps components in a key path expression to their type. Needed because /// KeyPathExpr + Index isn't an \c ASTNode and thus can't be stored in \c @@ -2269,29 +2264,23 @@ class ConstraintSystem { Type> KeyPathComponentTypes; - /// Same as \c addedNodeTypes for \c KeyPathComponentTypes. - llvm::SmallVector< - std::tuple> - addedKeyPathComponentTypes; - /// Maps a key path root, value, and decl context to the key path expression. - llvm::MapVector> + llvm::DenseMap> KeyPaths; /// Maps AST entries to their targets. - llvm::MapVector targets; + llvm::DenseMap targets; /// Contextual type information for expressions that are part of this /// constraint system. The second type, if valid, contains the type as it /// should appear in actual constraint. This will have unbound generic types /// opened, placeholder types converted to type variables, etc. - llvm::MapVector> contextualTypes; + llvm::DenseMap> contextualTypes; /// Information about each case label item tracked by the constraint system. - llvm::SmallMapVector - caseLabelItems; + llvm::SmallDenseMap caseLabelItems; /// Keep track of all of the potential throw sites. /// FIXME: This data structure should be replaced with something that @@ -2300,14 +2289,14 @@ class ConstraintSystem { /// A map of expressions to the ExprPatterns that they are being solved as /// a part of. - llvm::SmallMapVector exprPatterns; + llvm::SmallDenseMap exprPatterns; /// The set of parameters that have been inferred to be 'isolated'. - llvm::SmallSetVector isolatedParams; + llvm::SmallDenseSet isolatedParams; /// The set of closures that have been inferred to be "isolated by /// preconcurrency". - llvm::SmallSetVector preconcurrencyClosures; + llvm::SmallDenseSet preconcurrencyClosures; /// Maps closure parameters to type variables. llvm::DenseMap @@ -2342,7 +2331,7 @@ class ConstraintSystem { /// The set of implicit value conversions performed by the solver on /// a current path to reach a solution. - llvm::SmallMapVector + llvm::SmallDenseMap ImplicitValueConversions; /// The worklist of "active" constraints that should be revisited @@ -2408,11 +2397,11 @@ class ConstraintSystem { /// A mapping from the constraint locators for references to various /// names (e.g., member references, normal name references, possible /// constructions) to the argument lists for the call to that locator. - llvm::MapVector ArgumentLists; + llvm::DenseMap ArgumentLists; public: /// A map from argument expressions to their applied property wrapper expressions. - llvm::SmallMapVector, 4> + llvm::SmallDenseMap, 4> appliedPropertyWrappers; /// The locators of \c Defaultable constraints whose defaults were used. @@ -2421,17 +2410,11 @@ class ConstraintSystem { void recordDefaultedConstraint(ConstraintLocator *locator) { bool inserted = DefaultedConstraints.insert(locator).second; if (inserted) { - if (isRecordingChanges()) { - recordChange(SolverTrail::Change::recordedDefaultedConstraint(locator)); - } + if (solverState) + recordChange(SolverTrail::Change::RecordedDefaultedConstraint(locator)); } } - void removeDefaultedConstraint(ConstraintLocator *locator) { - bool erased = DefaultedConstraints.erase(locator); - ASSERT(erased); - } - /// A cache that stores the @dynamicCallable required methods implemented by /// types. llvm::DenseMap @@ -2440,12 +2423,12 @@ class ConstraintSystem { /// A cache of implicitly generated dot-member expressions used as roots /// for some `.callAsFunction` calls. The key here is "base" locator for /// the `.callAsFunction` member reference. - llvm::SmallMapVector + llvm::SmallDenseMap ImplicitCallAsFunctionRoots; /// The set of conformances synthesized during solving (i.e. for /// ad-hoc distributed `SerializationRequirement` conformances). - llvm::MapVector + llvm::DenseMap SynthesizedConformances; private: @@ -2640,24 +2623,6 @@ class ConstraintSystem { } generatedConstraints.erase(genStart, genEnd); - - for (unsigned constraintIdx : - range(scope->numDisabledConstraints, disabledConstraints.size())) { - if (disabledConstraints[constraintIdx]->isDisabled()) - disabledConstraints[constraintIdx]->setEnabled(); - } - disabledConstraints.erase( - disabledConstraints.begin() + scope->numDisabledConstraints, - disabledConstraints.end()); - - for (unsigned constraintIdx : - range(scope->numFavoredConstraints, favoredConstraints.size())) { - if (favoredConstraints[constraintIdx]->isFavored()) - favoredConstraints[constraintIdx]->setFavored(false); - } - favoredConstraints.erase( - favoredConstraints.begin() + scope->numFavoredConstraints, - favoredConstraints.end()); } /// Check whether constraint system is allowed to form solutions @@ -2666,19 +2631,12 @@ class ConstraintSystem { return AllowFreeTypeVariables != FreeTypeVariableBinding::Disallow; } - unsigned getNumDisabledConstraints() const { - return disabledConstraints.size(); - } - /// Disable the given constraint; this change will be rolled back /// when we exit the current solver scope. void disableConstraint(Constraint *constraint) { + ASSERT(!constraint->isDisabled()); constraint->setDisabled(); - disabledConstraints.push_back(constraint); - } - - unsigned getNumFavoredConstraints() const { - return favoredConstraints.size(); + Trail.recordChange(SolverTrail::Change::DisabledConstraint(constraint)); } /// Favor the given constraint; this change will be rolled back @@ -2687,7 +2645,7 @@ class ConstraintSystem { assert(!constraint->isFavored()); constraint->setFavored(); - favoredConstraints.push_back(constraint); + Trail.recordChange(SolverTrail::Change::FavoredConstraint(constraint)); } private: @@ -2711,9 +2669,7 @@ class ConstraintSystem { llvm::SmallVector< std::tuple, 4> scopes; - SmallVector disabledConstraints; - SmallVector favoredConstraints; - + /// Depth of the solution stack. unsigned depth = 0; }; @@ -2845,6 +2801,11 @@ class ConstraintSystem { /// Associate an argument list with a call at a given locator. void associateArgumentList(ConstraintLocator *locator, ArgumentList *args); + /// Same as associateArgumentList() except the locator points at the + /// argument list itself. Records a change in the trail. + void recordArgumentList(ConstraintLocator *locator, + ArgumentList *args); + /// If the given node is a function expression with a parent ApplyExpr, /// returns the apply, otherwise returns the node itself. ASTNode includingParentApply(ASTNode node); @@ -2894,64 +2855,6 @@ class ConstraintSystem { /// FIXME: Remove this. unsigned numFixes; - unsigned numAddedNodeTypes; - - unsigned numAddedKeyPathComponentTypes; - - unsigned numDisabledConstraints; - - unsigned numFavoredConstraints; - - unsigned numResultBuilderTransformed; - - /// The length of \c appliedPropertyWrappers - unsigned numAppliedPropertyWrappers; - - /// The length of \c ResolvedOverloads. - unsigned numResolvedOverloads; - - /// The length of \c ClosureTypes. - unsigned numInferredClosureTypes; - - /// The length of \c ImpliedResults. - unsigned numImpliedResults; - - /// The length of \c contextualTypes. - unsigned numContextualTypes; - - /// The length of \c targets. - unsigned numTargets; - - /// The length of \c caseLabelItems. - unsigned numCaseLabelItems; - - /// The length of \c potentialThrowSites. - unsigned numPotentialThrowSites; - - /// The length of \c exprPatterns. - unsigned numExprPatterns; - - /// The length of \c isolatedParams. - unsigned numIsolatedParams; - - /// The length of \c PreconcurrencyClosures. - unsigned numPreconcurrencyClosures; - - /// The length of \c ImplicitValueConversions. - unsigned numImplicitValueConversions; - - /// The length of \c KeyPaths. - unsigned numKeyPaths; - - /// The length of \c ArgumentLists. - unsigned numArgumentLists; - - /// The length of \c ImplicitCallAsFunctionRoots. - unsigned numImplicitCallAsFunctionRoots; - - /// The length of \c SynthesizedConformances. - unsigned numSynthesizedConformances; - /// The previous score. Score PreviousScore; @@ -3118,16 +3021,24 @@ class ConstraintSystem { } /// Record an implied result for a ReturnStmt or ThenStmt. - void recordImpliedResult(const Expr *E, ImpliedResultKind kind) { - assert(E); + void recordImpliedResult(Expr *E, ImpliedResultKind kind) { + ASSERT(E); auto inserted = ImpliedResults.insert({E, kind}).second; - assert(inserted && "Duplicate implied result?"); - (void)inserted; + ASSERT(inserted && "Duplicate implied result?"); + + if (solverState) + recordChange(SolverTrail::Change::RecordedImpliedResult(E)); + } + + /// Undo the above change. + void removeImpliedResult(Expr *E) { + bool erased = ImpliedResults.erase(E); + ASSERT(erased); } /// Whether the given expression is the implied result for either a ReturnStmt /// or ThenStmt, and if so, the kind of implied result. - std::optional isImpliedResult(const Expr *E) const { + std::optional isImpliedResult(Expr *E) const { auto result = ImpliedResults.find(E); if (result == ImpliedResults.end()) return std::nullopt; @@ -3136,10 +3047,20 @@ class ConstraintSystem { } void setClosureType(const ClosureExpr *closure, FunctionType *type) { - assert(closure); - assert(type && "Expected non-null type"); - assert(ClosureTypes.count(closure) == 0 && "Cannot reset closure type"); - ClosureTypes.insert({closure, type}); + ASSERT(closure); + ASSERT(type); + bool inserted = ClosureTypes.insert({closure, type}).second; + ASSERT(inserted); + + if (solverState) { + recordChange(SolverTrail::Change::RecordedClosureType( + const_cast(closure))); + } + } + + void removeClosureType(const ClosureExpr *closure) { + bool erased = ClosureTypes.erase(closure); + ASSERT(erased); } FunctionType *getClosureType(const ClosureExpr *closure) const { @@ -3205,21 +3126,45 @@ class ConstraintSystem { this->FavoredTypes[E] = T; } - /// Set the type in our type map for the given node. + /// Set the type in our type map for the given node, and record the change + /// in the trail. /// - /// The side tables are used through the expression type checker to avoid mutating nodes until - /// we know we have successfully type-checked them. + /// The side tables are used through the expression type checker to avoid + /// mutating nodes until we know we have successfully type-checked them. void setType(ASTNode node, Type type) { - assert(!node.isNull() && "Cannot set type information on null node"); - assert(type && "Expected non-null type"); + ASSERT(!node.isNull() && "Cannot set type information on null node"); + ASSERT(type && "Expected non-null type"); // Record the type. Type &entry = NodeTypes[node]; Type oldType = entry; entry = type; - // Record the fact that we ascribed a type to this node. - addedNodeTypes.push_back({node, oldType}); + if (oldType.getPointer() != type.getPointer()) { + // Record the fact that we assigned a type to this node. + if (solverState) + recordChange(SolverTrail::Change::RecordedNodeType(node, oldType)); + } + } + + /// Undo the above change. + void restoreType(ASTNode node, Type oldType) { + ASSERT(!node.isNull() && "Cannot set type information on null node"); + + if (oldType) { + auto found = NodeTypes.find(node); + ASSERT(found != NodeTypes.end()); + found->second = oldType; + } else { + bool erased = NodeTypes.erase(node); + ASSERT(erased); + } + } + + /// Check to see if we have a type for a node. + bool hasType(ASTNode node) const { + assert(!node.isNull() && "Expected non-null node"); + return NodeTypes.count(node) > 0; } /// Set the type in our type map for a given expression. The side @@ -3227,20 +3172,33 @@ class ConstraintSystem { /// avoid mutating expressions until we know we have successfully /// type-checked them. void setType(const KeyPathExpr *KP, unsigned I, Type T) { - assert(KP && "Expected non-null key path parameter!"); - assert(T && "Expected non-null type!"); + ASSERT(KP && "Expected non-null key path parameter!"); + ASSERT(T && "Expected non-null type!"); Type &entry = KeyPathComponentTypes[{KP, I}]; Type oldType = entry; entry = T; - addedKeyPathComponentTypes.push_back(std::make_tuple(KP, I, oldType)); + if (oldType.getPointer() != T.getPointer()) { + if (solverState) { + recordChange( + SolverTrail::Change::RecordedKeyPathComponentType( + KP, I, oldType)); + } + } } - /// Check to see if we have a type for a node. - bool hasType(ASTNode node) const { - assert(!node.isNull() && "Expected non-null node"); - return NodeTypes.count(node) > 0; + void restoreType(const KeyPathExpr *KP, unsigned I, Type T) { + ASSERT(KP && "Expected non-null key path parameter!"); + + if (T) { + auto found = KeyPathComponentTypes.find({KP, I}); + ASSERT(found != KeyPathComponentTypes.end()); + found->second = T; + } else { + bool erased = KeyPathComponentTypes.erase({KP, I}); + ASSERT(erased); + } } bool hasType(const KeyPathExpr *KP, unsigned I) const { @@ -3297,10 +3255,17 @@ class ConstraintSystem { } void setContextualInfo(ASTNode node, ContextualTypeInfo info) { - assert(bool(node) && "Expected non-null expression!"); - assert(contextualTypes.count(node) == 0 && - "Already set this contextual type"); - contextualTypes[node] = {info, Type()}; + ASSERT(bool(node) && "Expected non-null expression!"); + bool inserted = contextualTypes.insert({node, {info, Type()}}).second; + ASSERT(inserted); + + if (solverState) + recordChange(SolverTrail::Change::RecordedContextualInfo(node)); + } + + void removeContextualInfo(ASTNode node) { + bool erased = contextualTypes.erase(node); + ASSERT(erased); } std::optional getContextualTypeInfo(ASTNode node) const { @@ -3358,18 +3323,12 @@ class ConstraintSystem { } void setTargetFor(SyntacticElementTargetKey key, - SyntacticElementTarget target) { - assert(targets.count(key) == 0 && "Already set this target"); - targets.insert({key, target}); - } + SyntacticElementTarget target); + + void removeTargetFor(SyntacticElementTargetKey key); std::optional - getTargetFor(SyntacticElementTargetKey key) const { - auto known = targets.find(key); - if (known == targets.end()) - return std::nullopt; - return known->second; - } + getTargetFor(SyntacticElementTargetKey key) const; std::optional getAppliedResultBuilderTransform(AnyFunctionRef fn) const { @@ -3401,18 +3360,36 @@ class ConstraintSystem { } void setCaseLabelItemInfo(const CaseLabelItem *item, CaseLabelItemInfo info) { - assert(item != nullptr); - assert(caseLabelItems.count(item) == 0); - caseLabelItems[item] = info; + ASSERT(item); + bool inserted = caseLabelItems.insert({item, info}).second; + ASSERT(inserted); + + if (solverState) { + recordChange(SolverTrail::Change::RecordedCaseLabelItemInfo( + const_cast(item))); + } + } + + void removeCaseLabelItemInfo(const CaseLabelItem *item) { + bool erased = caseLabelItems.erase(item); + ASSERT(erased); } /// Record a given ExprPattern as the parent of its sub-expression. void setExprPatternFor(Expr *E, ExprPattern *EP) { - assert(E); - assert(EP); - auto inserted = exprPatterns.insert({E, EP}).second; - assert(inserted && "Mapping already defined?"); - (void)inserted; + ASSERT(E); + ASSERT(EP); + bool inserted = exprPatterns.insert({E, EP}).second; + ASSERT(inserted); + + if (solverState) + recordChange(SolverTrail::Change::RecordedExprPattern(E)); + } + + /// Record a given ExprPattern as the parent of its sub-expression. + void removeExprPatternFor(Expr *E) { + bool erased = exprPatterns.erase(E); + ASSERT(erased); } std::optional @@ -3429,6 +3406,14 @@ class ConstraintSystem { PotentialThrowSite::Kind kind, Type type, ConstraintLocatorBuilder locator); + /// Used by the above to update potentialThrowSites and record a change + /// in the trail. + void recordPotentialThrowSite(CatchNode catchNode, + PotentialThrowSite site); + + /// Undo the above change. + void removePotentialThrowSite(CatchNode catchNode); + /// Determine the caught error type for the given catch node. Type getCaughtErrorType(CatchNode node); @@ -3453,12 +3438,6 @@ class ConstraintSystem { void recordOpenedExistentialType(ConstraintLocator *locator, OpenedArchetypeType *opened); - /// Undo the above change. - void removeOpenedExistentialType(ConstraintLocator *locator) { - bool erased = OpenedExistentialTypes.erase(locator); - ASSERT(erased); - } - /// Get the opened element generic environment for the given locator. GenericEnvironment *getPackElementEnvironment(ConstraintLocator *locator, CanType shapeClass); @@ -3467,12 +3446,6 @@ class ConstraintSystem { void recordPackExpansionEnvironment(ConstraintLocator *locator, std::pair uuidAndShape); - /// Undo the above change. - void removePackExpansionEnvironment(ConstraintLocator *locator) { - bool erased = PackExpansionEnvironments.erase(locator); - ASSERT(erased); - } - /// Get the opened element generic environment for the given pack element. PackExpansionExpr *getPackEnvironment(PackElementExpr *packElement) const; @@ -3655,22 +3628,17 @@ class ConstraintSystem { void recordMatchCallArgumentResult(ConstraintLocator *locator, MatchCallArgumentResult result); - /// Undo the above change. - void removeMatchCallArgumentResult(ConstraintLocator *locator) { - bool erased = argumentMatchingChoices.erase(locator); - ASSERT(erased); - } - - /// Record implicitly generated `callAsFunction` with root at the - /// given expression, located at \c locator. - void recordCallAsFunction(UnresolvedDotExpr *root, ArgumentList *arguments, - ConstraintLocator *locator); + void recordImplicitCallAsFunctionRoot( + ConstraintLocator *locator, UnresolvedDotExpr *root); /// Record root, value, and declContext of keypath expression for use across - /// constraint system. - void recordKeyPath(KeyPathExpr *keypath, TypeVariableType *root, + /// constraint system, and add a change to the trail. + void recordKeyPath(const KeyPathExpr *keypath, TypeVariableType *root, TypeVariableType *value, DeclContext *dc); + /// Undo the above change. + void removeKeyPath(const KeyPathExpr *keypath); + /// Walk a closure AST to determine its effects. /// /// \returns a function's extended info describing the effects, as @@ -4209,6 +4177,20 @@ class ConstraintSystem { bool resolveClosure(TypeVariableType *typeVar, Type contextualType, ConstraintLocatorBuilder locator); + /// Used by the above to update isolatedParams and record a change in + /// the trail. + void recordIsolatedParam(ParamDecl *param); + + /// Undo the above change. + void removeIsolatedParam(ParamDecl *param); + + /// Used by the above to update preconcurrencyClosures and record a change in + /// the trail. + void recordPreconcurrencyClosure(const ClosureExpr *closure); + + /// Undo the above change. + void removePreconcurrencyClosure(const ClosureExpr *closure); + /// Given the fact that contextual type is now available for the type /// variable representing a pack expansion type, let's resolve the expansion. /// @@ -4424,13 +4406,11 @@ class ConstraintSystem { void recordOpenedType( ConstraintLocator *locator, ArrayRef openedTypes); - /// Undo the above change. - void removeOpenedType(ConstraintLocator *locator); - /// Record the set of opened types for the given locator. void recordOpenedTypes( ConstraintLocatorBuilder locator, - const OpenedTypeMap &replacements); + const OpenedTypeMap &replacements, + bool fixmeAllowDuplicates=false); /// Check whether the given type conforms to the given protocol and if /// so return a valid conformance reference. @@ -4944,6 +4924,9 @@ class ConstraintSystem { buildDisjunctionForOptionalVsUnderlying(boundTy, type, dynamicLocator); } + void recordResolvedOverload(ConstraintLocator *locator, + SelectedOverload choice); + /// Resolve the given overload set to the given choice. void resolveOverload(ConstraintLocator *locator, Type boundType, OverloadChoice choice, DeclContext *useDC); @@ -5078,6 +5061,9 @@ class ConstraintSystem { ConstraintLocatorBuilder locator, TypeMatchOptions flags); + void recordSynthesizedConformance(ConstraintLocator *locator, + ProtocolConformanceRef conformance); + /// Attempt to simplify the given conformance constraint. /// /// \param type The type being tested. @@ -5233,6 +5219,10 @@ class ConstraintSystem { TypeMatchOptions flags, ConstraintLocatorBuilder locator); + /// Update ImplicitValueConversions and record a change in the trail. + void recordImplicitValueConversion(ConstraintLocator *locator, + ConversionRestrictionKind restriction); + /// Simplify a conversion constraint by applying the given /// reduction rule, which is known to apply at the outermost level. SolutionKind simplifyRestrictedConstraint( @@ -5324,12 +5314,27 @@ class ConstraintSystem { ConstraintKind bodyResultConstraintKind, Type contextualType, ConstraintLocatorBuilder locator); + /// Used by matchResultBuilder() to update resultBuilderTransformed and record + /// a change in the trail. + void recordResultBuilderTransform(AnyFunctionRef fn, + AppliedBuilderTransform transformInfo); + + /// Undo the above change. + void removeResultBuilderTransform(AnyFunctionRef fn); + /// Matches a wrapped or projected value parameter type to its backing /// property wrapper type by applying the property wrapper. TypeMatchResult applyPropertyWrapperToParameter( Type wrapperType, Type paramType, ParamDecl *param, Identifier argLabel, ConstraintKind matchKind, ConstraintLocatorBuilder locator); + /// Used by applyPropertyWrapperToParameter() to update appliedPropertyWrappers + /// and record a change in the trail. + void applyPropertyWrapper(Expr *anchor, AppliedPropertyWrapper applied); + + /// Undo the above change. + void removePropertyWrapper(Expr *anchor); + /// Determine whether given type variable with its set of bindings is viable /// to be attempted on the next step of the solver. std::optional determineBestBindings( @@ -5358,22 +5363,10 @@ class ConstraintSystem { /// Record a particular disjunction choice and add a change to the trail. void recordDisjunctionChoice(ConstraintLocator *locator, unsigned index); - /// Undo the above change. - void removeDisjunctionChoice(ConstraintLocator *locator) { - bool erased = DisjunctionChoices.erase(locator); - ASSERT(erased); - } - /// Record applied disjunction and add a change to the trail. void recordAppliedDisjunction(ConstraintLocator *locator, FunctionType *type); - /// Undo the above change. - void removeAppliedDisjunction(ConstraintLocator *locator) { - bool erased = AppliedDisjunctions.erase(locator); - ASSERT(erased); - } - /// Filter the set of disjunction terms, keeping only those where the /// predicate returns \c true. /// @@ -5686,7 +5679,7 @@ class ConstraintSystem { } /// The overload sets that have already been resolved along the current path. - const llvm::MapVector & + const llvm::DenseMap & getResolvedOverloads() const { return ResolvedOverloads; } diff --git a/lib/Sema/BuilderTransform.cpp b/lib/Sema/BuilderTransform.cpp index 9793d5c1c2dff..658602b4610f4 100644 --- a/lib/Sema/BuilderTransform.cpp +++ b/lib/Sema/BuilderTransform.cpp @@ -1211,15 +1211,7 @@ ConstraintSystem::matchResultBuilder(AnyFunctionRef fn, Type builderType, transformInfo.transformedBody = transformedBody->second; // Record the transformation. - assert( - std::find_if( - resultBuilderTransformed.begin(), resultBuilderTransformed.end(), - [&](const std::pair &elt) { - return elt.first == fn; - }) == resultBuilderTransformed.end() && - "already transformed this body along this path!?!"); - resultBuilderTransformed.insert( - std::make_pair(fn, std::move(transformInfo))); + recordResultBuilderTransform(fn, std::move(transformInfo)); if (generateConstraints(fn, transformInfo.transformedBody)) return getTypeMatchFailure(locator); @@ -1227,6 +1219,22 @@ ConstraintSystem::matchResultBuilder(AnyFunctionRef fn, Type builderType, return getTypeMatchSuccess(); } +void ConstraintSystem::recordResultBuilderTransform(AnyFunctionRef fn, + AppliedBuilderTransform transformInfo) { + bool inserted = resultBuilderTransformed.insert( + std::make_pair(fn, std::move(transformInfo))).second; + ASSERT(inserted); + + if (solverState) + recordChange(SolverTrail::Change::RecordedResultBuilderTransform(fn)); +} + +/// Undo the above change. +void ConstraintSystem::removeResultBuilderTransform(AnyFunctionRef fn) { + bool erased = resultBuilderTransformed.erase(fn); + ASSERT(erased); +} + namespace { class ReturnStmtFinder : public ASTWalker { std::vector ReturnStmts; diff --git a/lib/Sema/CSApply.cpp b/lib/Sema/CSApply.cpp index ecfd76aea94be..646283676864f 100644 --- a/lib/Sema/CSApply.cpp +++ b/lib/Sema/CSApply.cpp @@ -8539,11 +8539,9 @@ bool ExprRewriter::isDistributedThunk(ConcreteDeclRef ref, Expr *context) { // If this is a method reference on an potentially isolated // actor then it cannot be a remote thunk. bool isPotentiallyIsolated = isPotentiallyIsolatedActor( - actor, - [&](ParamDecl *P) { - return P->isIsolated() || - llvm::is_contained(solution.isolatedParams, P); - }); + actor, [&](ParamDecl *P) { + return P->isIsolated() || solution.isolatedParams.count(P); + }); // Adjust the declaration context to the innermost context that is neither // a local function nor a closure, so that the actor reference is checked diff --git a/lib/Sema/CSBindings.cpp b/lib/Sema/CSBindings.cpp index 2f9a971ab60ec..97075fd87dc20 100644 --- a/lib/Sema/CSBindings.cpp +++ b/lib/Sema/CSBindings.cpp @@ -21,9 +21,12 @@ #include "swift/Sema/ConstraintGraph.h" #include "swift/Sema/ConstraintSystem.h" #include "llvm/ADT/SetVector.h" +#include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" #include +#define DEBUG_TYPE "PotentialBindings" + using namespace swift; using namespace constraints; using namespace inference; @@ -1774,7 +1777,7 @@ void PotentialBindings::infer(Constraint *constraint) { // Record the change, if there are active scopes. if (CS.isRecordingChanges()) - CS.recordChange(SolverTrail::Change::inferredBindings(TypeVar, constraint)); + CS.recordChange(SolverTrail::Change::InferredBindings(TypeVar, constraint)); switch (constraint->getKind()) { case ConstraintKind::Bind: @@ -1949,7 +1952,14 @@ void PotentialBindings::retract(Constraint *constraint) { // Record the change, if there are active scopes. if (CS.isRecordingChanges()) - CS.recordChange(SolverTrail::Change::retractedBindings(TypeVar, constraint)); + CS.recordChange(SolverTrail::Change::RetractedBindings(TypeVar, constraint)); + + LLVM_DEBUG( + llvm::dbgs() << Constraints.size() << " " << Bindings.size() << " " + << Protocols.size() << " " << Literals.size() << " " + << AdjacentVars.size() << " " << DelayedBy.size() << " " + << SubtypeOf.size() << " " << SupertypeOf.size() << " " + << EquivalentTo.size() << "\n"); Bindings.erase( llvm::remove_if(Bindings, diff --git a/lib/Sema/CSGen.cpp b/lib/Sema/CSGen.cpp index ddad265d565c2..f7936a4049c61 100644 --- a/lib/Sema/CSGen.cpp +++ b/lib/Sema/CSGen.cpp @@ -4995,7 +4995,7 @@ bool ConstraintSystem::generateConstraints(StmtCondition condition, } Type boolTy = boolDecl->getDeclaredInterfaceType(); - for (const auto &condElement : condition) { + for (auto &condElement : condition) { switch (condElement.getKind()) { case StmtConditionElement::CK_Availability: // Nothing to do here. @@ -5046,6 +5046,22 @@ bool ConstraintSystem::generateConstraints(StmtCondition condition, return false; } +void ConstraintSystem::applyPropertyWrapper( + Expr *anchor, AppliedPropertyWrapper applied) { + appliedPropertyWrappers[anchor].push_back(applied); + + if (solverState) + recordChange(SolverTrail::Change::AppliedPropertyWrapper(anchor)); +} + +void ConstraintSystem::removePropertyWrapper(Expr *anchor) { + auto found = appliedPropertyWrappers.find(anchor); + ASSERT(found != appliedPropertyWrappers.end()); + auto &wrappers = found->second; + ASSERT(!wrappers.empty()); + wrappers.pop_back(); +} + ConstraintSystem::TypeMatchResult ConstraintSystem::applyPropertyWrapperToParameter( Type wrapperType, Type paramType, ParamDecl *param, Identifier argLabel, @@ -5079,13 +5095,13 @@ ConstraintSystem::applyPropertyWrapperToParameter( setType(param->getPropertyWrapperProjectionVar(), projectionType); } - appliedPropertyWrappers[anchor].push_back({ wrapperType, PropertyWrapperInitKind::ProjectedValue }); + applyPropertyWrapper(anchor, { wrapperType, PropertyWrapperInitKind::ProjectedValue }); } else if (param->hasExternalPropertyWrapper()) { Type wrappedValueType = computeWrappedValueType(param, wrapperType); addConstraint(matchKind, paramType, wrappedValueType, locator); setType(param->getPropertyWrapperWrappedValueVar(), wrappedValueType); - appliedPropertyWrappers[anchor].push_back({ wrapperType, PropertyWrapperInitKind::WrappedValue }); + applyPropertyWrapper(anchor, { wrapperType, PropertyWrapperInitKind::WrappedValue }); } else { return getTypeMatchFailure(locator); } diff --git a/lib/Sema/CSSimplify.cpp b/lib/Sema/CSSimplify.cpp index 9b36ba1a58621..0ef228869f576 100644 --- a/lib/Sema/CSSimplify.cpp +++ b/lib/Sema/CSSimplify.cpp @@ -8329,6 +8329,16 @@ ConstraintSystem::SolutionKind ConstraintSystem::simplifyConformsToConstraint( return matchExistentialTypes(type, protocol, kind, flags, locator); } +void ConstraintSystem::recordSynthesizedConformance( + ConstraintLocator *locator, + ProtocolConformanceRef conformance) { + bool inserted = SynthesizedConformances.insert({locator, conformance}).second; + ASSERT(inserted); + + if (solverState) + recordChange(SolverTrail::Change::RecordedSynthesizedConformance(locator)); +} + ConstraintSystem::SolutionKind ConstraintSystem::simplifyConformsToConstraint( Type type, ProtocolDecl *protocol, @@ -8487,7 +8497,9 @@ ConstraintSystem::SolutionKind ConstraintSystem::simplifyConformsToConstraint( ProtocolConformanceRef synthesized(protocol); auto witnessLoc = getConstraintLocator( locator.getAnchor(), LocatorPathElt::Witness(witness)); - SynthesizedConformances.insert({witnessLoc, synthesized}); + // FIXME: Why are we recording the same locator more than once here? + if (SynthesizedConformances.count(witnessLoc) == 0) + recordSynthesizedConformance(witnessLoc, synthesized); return recordConformance(synthesized); }; @@ -11513,6 +11525,36 @@ static Type getOpenedResultBuilderTypeFor(ConstraintSystem &cs, return builderType; } +void ConstraintSystem::recordIsolatedParam(ParamDecl *param) { + bool inserted = isolatedParams.insert(param).second; + ASSERT(inserted); + + if (solverState) + recordChange(SolverTrail::Change::RecordedIsolatedParam(param)); +} + +void ConstraintSystem::removeIsolatedParam(ParamDecl *param) { + bool erased = isolatedParams.erase(param); + ASSERT(erased); +} + +void ConstraintSystem::recordPreconcurrencyClosure( + const ClosureExpr *closure) { + bool inserted = preconcurrencyClosures.insert(closure).second; + ASSERT(inserted); + + if (solverState) { + recordChange(SolverTrail::Change::RecordedPreconcurrencyClosure( + const_cast(closure))); + } +} + +void ConstraintSystem::removePreconcurrencyClosure( + const ClosureExpr *closure) { + bool erased = preconcurrencyClosures.erase(closure); + ASSERT(erased); +} + bool ConstraintSystem::resolveClosure(TypeVariableType *typeVar, Type contextualType, ConstraintLocatorBuilder locator) { @@ -11522,7 +11564,7 @@ bool ConstraintSystem::resolveClosure(TypeVariableType *typeVar, // Note if this closure is isolated by preconcurrency. if (hasPreconcurrencyCallee(locator)) - preconcurrencyClosures.insert(closure); + recordPreconcurrencyClosure(closure); // Let's look through all optionals associated with contextual // type to make it possible to infer parameter/result type of @@ -11658,7 +11700,7 @@ bool ConstraintSystem::resolveClosure(TypeVariableType *typeVar, // Note when a parameter is inferred to be isolated. if (contextualParam->isIsolated() && !flags.isIsolated() && paramDecl) - isolatedParams.insert(paramDecl); + recordIsolatedParam(paramDecl); // Carry-over the ownership specifier from the contextual parameter. auto paramOwnership = @@ -12869,7 +12911,11 @@ createImplicitRootForCallAsFunction(ConstraintSystem &cs, Type refType, // Record a type of the new reference in the constraint system. cs.setType(implicitRef, refType); // Record new `.callAsFunction` in the constraint system. - cs.recordCallAsFunction(implicitRef, arguments, calleeLocator); + cs.recordImplicitCallAsFunctionRoot(calleeLocator, implicitRef); + + auto *implicitRefLocator = cs.getConstraintLocator( + implicitRef, ConstraintLocator::ApplyArgument); + cs.associateArgumentList(implicitRefLocator, arguments); } return implicitRef; @@ -14052,6 +14098,17 @@ void ConstraintSystem::addRestrictedConstraint( TMF_GenerateConstraints, locator); } +void ConstraintSystem::recordImplicitValueConversion( + ConstraintLocator *locator, + ConversionRestrictionKind restriction) { + bool inserted = ImplicitValueConversions.insert( + {getConstraintLocator(locator), restriction}).second; + ASSERT(inserted); + + if (solverState) + recordChange(SolverTrail::Change::RecordedImplicitValueConversion(locator)); +} + /// Given that we have a conversion constraint between two types, and /// that the given constraint-reduction rule applies between them at /// the top level, apply it and generate any necessary recursive @@ -14624,7 +14681,7 @@ ConstraintSystem::simplifyRestrictedConstraintImpl( getASTContext(), {Argument(SourceLoc(), Identifier(), nullptr)}, /*firstTrailingClosureIndex=*/std::nullopt, AllocationArena::ConstraintSolver); - ArgumentLists.insert({argumentsLoc, argList}); + recordArgumentList(argumentsLoc, argList); } auto *memberTypeLoc = getConstraintLocator( @@ -14927,25 +14984,37 @@ void ConstraintSystem::recordMatchCallArgumentResult( ConstraintLocator *locator, MatchCallArgumentResult result) { assert(locator->isLastElement()); bool inserted = argumentMatchingChoices.insert({locator, result}).second; - if (inserted) { - if (isRecordingChanges()) - recordChange(SolverTrail::Change::recordedMatchCallArgumentResult(locator)); - } + ASSERT(inserted); + + if (solverState) + recordChange(SolverTrail::Change::RecordedMatchCallArgumentResult(locator)); } -void ConstraintSystem::recordCallAsFunction(UnresolvedDotExpr *root, - ArgumentList *arguments, - ConstraintLocator *locator) { - ImplicitCallAsFunctionRoots.insert({locator, root}); +void ConstraintSystem::recordImplicitCallAsFunctionRoot( + ConstraintLocator *locator, UnresolvedDotExpr *root) { + bool inserted = ImplicitCallAsFunctionRoots.insert({locator, root}).second; + ASSERT(inserted); - associateArgumentList( - getConstraintLocator(root, ConstraintLocator::ApplyArgument), arguments); + if (solverState) + recordChange(SolverTrail::Change::RecordedImplicitCallAsFunctionRoot(locator)); } -void ConstraintSystem::recordKeyPath(KeyPathExpr *keypath, +void ConstraintSystem::recordKeyPath(const KeyPathExpr *keypath, TypeVariableType *root, TypeVariableType *value, DeclContext *dc) { - KeyPaths.insert(std::make_pair(keypath, std::make_tuple(root, value, dc))); + bool inserted = KeyPaths.insert( + std::make_pair(keypath, std::make_tuple(root, value, dc))).second; + ASSERT(inserted); + + if (solverState) { + recordChange(SolverTrail::Change::RecordedKeyPath( + const_cast(keypath))); + } +} + +void ConstraintSystem::removeKeyPath(const KeyPathExpr *keypath) { + bool erased = KeyPaths.erase(keypath); + ASSERT(erased); } ConstraintSystem::SolutionKind ConstraintSystem::simplifyFixConstraint( diff --git a/lib/Sema/CSSolver.cpp b/lib/Sema/CSSolver.cpp index 2a7c2a8d3d2ab..6859c41fc25c7 100644 --- a/lib/Sema/CSSolver.cpp +++ b/lib/Sema/CSSolver.cpp @@ -218,7 +218,8 @@ Solution ConstraintSystem::finalize() { solution.contextualTypes.push_back({entry.first, entry.second.first}); } - solution.targets = targets; + for (auto &target : targets) + solution.targets.insert(target); for (const auto &item : caseLabelItems) solution.caseLabelItems.insert(item); @@ -230,10 +231,10 @@ Solution ConstraintSystem::finalize() { solution.exprPatterns.insert(pattern); for (const auto ¶m : isolatedParams) - solution.isolatedParams.push_back(param); + solution.isolatedParams.insert(param); - for (const auto &closure : preconcurrencyClosures) - solution.preconcurrencyClosures.push_back(closure); + for (auto closure : preconcurrencyClosures) + solution.preconcurrencyClosures.insert(closure); for (const auto &transformed : resultBuilderTransformed) { solution.resultBuilderTransformed.insert(transformed); @@ -290,8 +291,10 @@ void ConstraintSystem::applySolution(const Solution &solution) { // Register overload choices. // FIXME: Copy these directly into some kind of partial solution? - for (auto overload : solution.overloadChoices) - ResolvedOverloads.insert(overload); + for (auto overload : solution.overloadChoices) { + if (!ResolvedOverloads.count(overload.first)) + recordResolvedOverload(overload.first, overload.second); + } // Register constraint restrictions. // FIXME: Copy these directly into some kind of partial solution? @@ -304,51 +307,64 @@ void ConstraintSystem::applySolution(const Solution &solution) { // Register the solution's disjunction choices. for (auto &choice : solution.DisjunctionChoices) { - recordDisjunctionChoice(choice.first, choice.second); + if (DisjunctionChoices.count(choice.first) == 0) + recordDisjunctionChoice(choice.first, choice.second); } // Register the solution's applied disjunctions. for (auto &choice : solution.AppliedDisjunctions) { - recordAppliedDisjunction(choice.first, choice.second); + if (AppliedDisjunctions.count(choice.first) == 0) + recordAppliedDisjunction(choice.first, choice.second); } // Remember all of the argument/parameter matching choices we made. for (auto &argumentMatch : solution.argumentMatchingChoices) { - recordMatchCallArgumentResult(argumentMatch.first, argumentMatch.second); + if (argumentMatchingChoices.count(argumentMatch.first) == 0) + recordMatchCallArgumentResult(argumentMatch.first, argumentMatch.second); } // Remember implied results. - for (auto impliedResult : solution.ImpliedResults) - ImpliedResults.insert(impliedResult); + for (auto impliedResult : solution.ImpliedResults) { + if (ImpliedResults.count(impliedResult.first) == 0) + recordImpliedResult(impliedResult.first, impliedResult.second); + } // Register the solution's opened types. for (const auto &opened : solution.OpenedTypes) { - recordOpenedType(opened.first, opened.second); + if (OpenedTypes.count(opened.first) == 0) + recordOpenedType(opened.first, opened.second); } // Register the solution's opened existential types. for (const auto &openedExistential : solution.OpenedExistentialTypes) { - recordOpenedExistentialType(openedExistential.first, openedExistential.second); + if (OpenedExistentialTypes.count(openedExistential.first) == 0) { + recordOpenedExistentialType(openedExistential.first, + openedExistential.second); + } } // Register the solution's opened pack expansion types. for (const auto &expansion : solution.OpenedPackExpansionTypes) { - recordOpenedPackExpansionType(expansion.first, expansion.second); + if (OpenedPackExpansionTypes.count(expansion.first) == 0) + recordOpenedPackExpansionType(expansion.first, expansion.second); } // Register the solutions's pack expansion environments. for (const auto &expansion : solution.PackExpansionEnvironments) { - recordPackExpansionEnvironment(expansion.first, expansion.second); + if (PackExpansionEnvironments.count(expansion.first) == 0) + recordPackExpansionEnvironment(expansion.first, expansion.second); } // Register the solutions's pack environments. for (auto &packEnvironment : solution.PackEnvironments) { - addPackEnvironment(packEnvironment.first, packEnvironment.second); + if (PackEnvironments.count(packEnvironment.first) == 0) + addPackEnvironment(packEnvironment.first, packEnvironment.second); } // Register the defaulted type variables. - for (auto *locator : solution.DefaultedConstraints) + for (auto *locator : solution.DefaultedConstraints) { recordDefaultedConstraint(locator); + } // Add the node types back. for (auto &nodeType : solution.nodeTypes) { @@ -362,7 +378,12 @@ void ConstraintSystem::applySolution(const Solution &solution) { // Add key paths. for (const auto &keypath : solution.KeyPaths) { - KeyPaths.insert(keypath); + if (KeyPaths.count(keypath.first) == 0) { + recordKeyPath(keypath.first, + std::get<0>(keypath.second), + std::get<1>(keypath.second), + std::get<2>(keypath.second)); + } } // Add the contextual types. @@ -383,53 +404,78 @@ void ConstraintSystem::applySolution(const Solution &solution) { setCaseLabelItemInfo(info.first, info.second); } - potentialThrowSites.insert(potentialThrowSites.end(), - solution.potentialThrowSites.begin(), - solution.potentialThrowSites.end()); + auto sites = ArrayRef(solution.potentialThrowSites); + ASSERT(sites.size() >= potentialThrowSites.size()); + for (const auto &site : sites.slice(potentialThrowSites.size())) { + potentialThrowSites.push_back(site); + } for (auto param : solution.isolatedParams) { - isolatedParams.insert(param); + if (isolatedParams.count(param) == 0) + recordIsolatedParam(param); } - for (auto &pair : solution.exprPatterns) - exprPatterns.insert(pair); + for (auto &pair : solution.exprPatterns) { + if (exprPatterns.count(pair.first) == 0) + setExprPatternFor(pair.first, pair.second); + } for (auto closure : solution.preconcurrencyClosures) { - preconcurrencyClosures.insert(closure); + if (preconcurrencyClosures.count(closure) == 0) + recordPreconcurrencyClosure(closure); } for (const auto &transformed : solution.resultBuilderTransformed) { - resultBuilderTransformed.insert(transformed); + if (resultBuilderTransformed.count(transformed.first) == 0) + recordResultBuilderTransform(transformed.first, transformed.second); } for (const auto &appliedWrapper : solution.appliedPropertyWrappers) { - appliedPropertyWrappers.insert(appliedWrapper); + auto found = appliedPropertyWrappers.find(appliedWrapper.first); + if (found == appliedPropertyWrappers.end()) { + appliedPropertyWrappers.insert(appliedWrapper); + } else { + auto &existing = found->second; + ASSERT(existing.size() <= appliedWrapper.second.size()); + existing = appliedWrapper.second; + } } for (auto &valueConversion : solution.ImplicitValueConversions) { - ImplicitValueConversions.insert(valueConversion); + if (ImplicitValueConversions.count(valueConversion.first) == 0) { + recordImplicitValueConversion(valueConversion.first, + valueConversion.second); + } } // Register the argument lists. for (auto &argListMapping : solution.argumentLists) { - ArgumentLists.insert(argListMapping); + if (ArgumentLists.count(argListMapping.first) == 0) + recordArgumentList(argListMapping.first, argListMapping.second); } for (auto &implicitRoot : solution.ImplicitCallAsFunctionRoots) { - ImplicitCallAsFunctionRoots.insert(implicitRoot); + if (ImplicitCallAsFunctionRoots.count(implicitRoot.first) == 0) + recordImplicitCallAsFunctionRoot(implicitRoot.first, implicitRoot.second); } for (auto &synthesized : solution.SynthesizedConformances) { - SynthesizedConformances.insert(synthesized); + if (SynthesizedConformances.count(synthesized.first) == 0) + recordSynthesizedConformance(synthesized.first, synthesized.second); } // Register any fixes produced along this path. - for (auto *fix : solution.Fixes) - addFix(fix); + for (auto *fix : solution.Fixes) { + if (Fixes.count(fix) == 0) + addFix(fix); + } // Register fixed requirements. - for (auto fix : solution.FixedRequirements) - recordFixedRequirement(std::get<0>(fix), std::get<1>(fix), std::get<2>(fix)); + for (auto fix : solution.FixedRequirements) { + recordFixedRequirement(std::get<0>(fix), + std::get<1>(fix), + std::get<2>(fix)); + } } bool ConstraintSystem::simplify() { // While we have a constraint in the worklist, process it. @@ -663,27 +709,6 @@ ConstraintSystem::SolverScope::SolverScope(ConstraintSystem &cs) numTypeVariables = cs.TypeVariables.size(); numFixes = cs.Fixes.size(); - numAddedNodeTypes = cs.addedNodeTypes.size(); - numAddedKeyPathComponentTypes = cs.addedKeyPathComponentTypes.size(); - numKeyPaths = cs.KeyPaths.size(); - numDisabledConstraints = cs.solverState->getNumDisabledConstraints(); - numFavoredConstraints = cs.solverState->getNumFavoredConstraints(); - numResultBuilderTransformed = cs.resultBuilderTransformed.size(); - numAppliedPropertyWrappers = cs.appliedPropertyWrappers.size(); - numResolvedOverloads = cs.ResolvedOverloads.size(); - numInferredClosureTypes = cs.ClosureTypes.size(); - numImpliedResults = cs.ImpliedResults.size(); - numContextualTypes = cs.contextualTypes.size(); - numTargets = cs.targets.size(); - numCaseLabelItems = cs.caseLabelItems.size(); - numPotentialThrowSites = cs.potentialThrowSites.size(); - numExprPatterns = cs.exprPatterns.size(); - numIsolatedParams = cs.isolatedParams.size(); - numPreconcurrencyClosures = cs.preconcurrencyClosures.size(); - numImplicitValueConversions = cs.ImplicitValueConversions.size(); - numArgumentLists = cs.ArgumentLists.size(); - numImplicitCallAsFunctionRoots = cs.ImplicitCallAsFunctionRoots.size(); - numSynthesizedConformances = cs.SynthesizedConformances.size(); PreviousScore = cs.CurrentScore; @@ -699,8 +724,6 @@ ConstraintSystem::SolverScope::~SolverScope() { // Erase the end of various lists. truncate(cs.TypeVariables, numTypeVariables); - truncate(cs.ResolvedOverloads, numResolvedOverloads); - // Move any remaining active constraints into the inactive list. if (!cs.ActiveConstraints.empty()) { for (auto &constraint : cs.ActiveConstraints) { @@ -718,79 +741,6 @@ ConstraintSystem::SolverScope::~SolverScope() { // constraints introduced by the current scope. cs.solverState->rollback(this); - // Remove any node types we registered. - for (unsigned i : - reverse(range(numAddedNodeTypes, cs.addedNodeTypes.size()))) { - auto node = cs.addedNodeTypes[i].first; - if (Type oldType = cs.addedNodeTypes[i].second) - cs.NodeTypes[node] = oldType; - else - cs.NodeTypes.erase(node); - } - truncate(cs.addedNodeTypes, numAddedNodeTypes); - - // Remove any node types we registered. - for (unsigned i : reverse(range(numAddedKeyPathComponentTypes, - cs.addedKeyPathComponentTypes.size()))) { - auto KeyPath = std::get<0>(cs.addedKeyPathComponentTypes[i]); - auto KeyPathIndex = std::get<1>(cs.addedKeyPathComponentTypes[i]); - if (Type oldType = std::get<2>(cs.addedKeyPathComponentTypes[i])) { - cs.KeyPathComponentTypes[{KeyPath, KeyPathIndex}] = oldType; - } else { - cs.KeyPathComponentTypes.erase({KeyPath, KeyPathIndex}); - } - } - truncate(cs.addedKeyPathComponentTypes, numAddedKeyPathComponentTypes); - - /// Remove any key path expressions. - truncate(cs.KeyPaths, numKeyPaths); - - /// Remove any builder transformed closures. - truncate(cs.resultBuilderTransformed, numResultBuilderTransformed); - - // Remove any applied property wrappers. - truncate(cs.appliedPropertyWrappers, numAppliedPropertyWrappers); - - // Remove any inferred closure types (e.g. used in result builder body). - truncate(cs.ClosureTypes, numInferredClosureTypes); - - // Remove any implied results. - truncate(cs.ImpliedResults, numImpliedResults); - - // Remove any contextual types. - truncate(cs.contextualTypes, numContextualTypes); - - // Remove any targets. - truncate(cs.targets, numTargets); - - // Remove any case label item infos. - truncate(cs.caseLabelItems, numCaseLabelItems); - - // Remove any potential throw sites. - truncate(cs.potentialThrowSites, numPotentialThrowSites); - - // Remove any ExprPattern mappings. - truncate(cs.exprPatterns, numExprPatterns); - - // Remove any isolated parameters. - truncate(cs.isolatedParams, numIsolatedParams); - - // Remove any preconcurrency closures. - truncate(cs.preconcurrencyClosures, numPreconcurrencyClosures); - - // Remove any implicit value conversions. - truncate(cs.ImplicitValueConversions, numImplicitValueConversions); - - // Remove any argument lists no longer in scope. - truncate(cs.ArgumentLists, numArgumentLists); - - // Remove any implicitly generated root expressions for `.callAsFunction` - // which are no longer in scope. - truncate(cs.ImplicitCallAsFunctionRoots, numImplicitCallAsFunctionRoots); - - // Remove any implicitly synthesized conformances. - truncate(cs.SynthesizedConformances, numSynthesizedConformances); - // Reset the previous score. cs.CurrentScore = PreviousScore; @@ -1794,23 +1744,29 @@ ConstraintSystem::filterDisjunction( llvm::errs().indent(indent) << ")\n"; } - if (restoreOnFail) - constraintsToRestoreOnFail.push_back(constraint); - - if (solverState) - solverState->disableConstraint(constraint); - else - constraint->setDisabled(); + if (!constraint->isDisabled()) { + if (restoreOnFail) + constraintsToRestoreOnFail.push_back(constraint); + else if (solverState) + solverState->disableConstraint(constraint); + else + constraint->setDisabled(); + } } - switch (numEnabledTerms) { - case 0: + if (numEnabledTerms == 0) + return SolutionKind::Error; + + if (restoreOnFail) { for (auto constraint : constraintsToRestoreOnFail) { - constraint->setEnabled(); + if (solverState) + solverState->disableConstraint(constraint); + else + constraint->setDisabled(); } - return SolutionKind::Error; + } - case 1: { + if (numEnabledTerms == 1) { // Only a single constraint remains. Retire the disjunction and make // the remaining constraint active. auto choice = disjunction->getNestedConstraints()[choiceIdx]; @@ -1859,9 +1815,7 @@ ConstraintSystem::filterDisjunction( return failedConstraint ? SolutionKind::Unsolved : SolutionKind::Solved; } - default: - return SolutionKind::Unsolved; - } + return SolutionKind::Unsolved; } // Attempt to find a disjunction of bind constraints where all options diff --git a/lib/Sema/CSSyntacticElement.cpp b/lib/Sema/CSSyntacticElement.cpp index 850ff03e5a7ad..10e87eeab0517 100644 --- a/lib/Sema/CSSyntacticElement.cpp +++ b/lib/Sema/CSSyntacticElement.cpp @@ -26,6 +26,36 @@ using namespace swift; using namespace swift::constraints; +void ConstraintSystem::setTargetFor(SyntacticElementTargetKey key, + SyntacticElementTarget target) { + bool inserted = targets.insert({key, target}).second; + ASSERT(inserted); + + if (solverState) + recordChange(SolverTrail::Change::RecordedTarget(key)); +} + +void ConstraintSystem::removeTargetFor(SyntacticElementTargetKey key) { + bool erased = targets.erase(key); + ASSERT(erased); +} + +std::optional +ConstraintSystem::getTargetFor(SyntacticElementTargetKey key) const { + auto known = targets.find(key); + if (known == targets.end()) + return std::nullopt; + return known->second; +} + +std::optional +Solution::getTargetFor(SyntacticElementTargetKey key) const { + auto known = targets.find(key); + if (known == targets.end()) + return std::nullopt; + return known->second; +} + namespace { // Produce an implicit empty tuple expression. @@ -2551,7 +2581,7 @@ bool ConstraintSystem::applySolution(AnyFunctionRef fn, param->setIsolated(true); } - if (llvm::is_contained(solution.preconcurrencyClosures, closure)) + if (solution.preconcurrencyClosures.count(closure)) closure->setIsolatedByPreconcurrency(); // Coerce the result type, if it was written explicitly. diff --git a/lib/Sema/CSTrail.cpp b/lib/Sema/CSTrail.cpp index 8f34347e78691..6f9ea91605125 100644 --- a/lib/Sema/CSTrail.cpp +++ b/lib/Sema/CSTrail.cpp @@ -33,6 +33,13 @@ using namespace constraints; #define DEBUG_TYPE "SolverTrail" +SolverTrail::SolverTrail(ConstraintSystem &cs) + : CS(cs) { + for (unsigned i = 0; i <= unsigned(ChangeKind::Last); ++i) { + Profile[i] = 0; + } +} + SolverTrail::~SolverTrail() { // If constraint system is in an invalid state, it's // possible that constraint graph is corrupted as well @@ -41,8 +48,34 @@ SolverTrail::~SolverTrail() { ASSERT(Changes.empty() && "Trail corrupted"); } +#define LOCATOR_CHANGE(Name, _) \ + SolverTrail::Change \ + SolverTrail::Change::Name(ConstraintLocator *locator) { \ + Change result; \ + result.Kind = ChangeKind::Name; \ + result.TheLocator = locator; \ + return result; \ + } +#define EXPR_CHANGE(Name) \ + SolverTrail::Change \ + SolverTrail::Change::Name(Expr *expr) { \ + Change result; \ + result.Kind = ChangeKind::Name; \ + result.TheExpr = expr; \ + return result; \ + } +#define CLOSURE_CHANGE(Name) \ + SolverTrail::Change \ + SolverTrail::Change::Name(ClosureExpr *closure) { \ + Change result; \ + result.Kind = ChangeKind::Name; \ + result.TheClosure = closure; \ + return result; \ + } +#include "swift/Sema/CSTrail.def" + SolverTrail::Change -SolverTrail::Change::addedTypeVariable(TypeVariableType *typeVar) { +SolverTrail::Change::AddedTypeVariable(TypeVariableType *typeVar) { Change result; result.Kind = ChangeKind::AddedTypeVariable; result.TypeVar = typeVar; @@ -50,7 +83,7 @@ SolverTrail::Change::addedTypeVariable(TypeVariableType *typeVar) { } SolverTrail::Change -SolverTrail::Change::addedConstraint(TypeVariableType *typeVar, +SolverTrail::Change::AddedConstraint(TypeVariableType *typeVar, Constraint *constraint) { Change result; result.Kind = ChangeKind::AddedConstraint; @@ -60,7 +93,7 @@ SolverTrail::Change::addedConstraint(TypeVariableType *typeVar, } SolverTrail::Change -SolverTrail::Change::removedConstraint(TypeVariableType *typeVar, +SolverTrail::Change::RemovedConstraint(TypeVariableType *typeVar, Constraint *constraint) { Change result; result.Kind = ChangeKind::RemovedConstraint; @@ -70,7 +103,7 @@ SolverTrail::Change::removedConstraint(TypeVariableType *typeVar, } SolverTrail::Change -SolverTrail::Change::extendedEquivalenceClass(TypeVariableType *typeVar, +SolverTrail::Change::ExtendedEquivalenceClass(TypeVariableType *typeVar, unsigned prevSize) { Change result; result.Kind = ChangeKind::ExtendedEquivalenceClass; @@ -80,7 +113,7 @@ SolverTrail::Change::extendedEquivalenceClass(TypeVariableType *typeVar, } SolverTrail::Change -SolverTrail::Change::relatedTypeVariables(TypeVariableType *typeVar, +SolverTrail::Change::RelatedTypeVariables(TypeVariableType *typeVar, TypeVariableType *otherTypeVar) { Change result; result.Kind = ChangeKind::RelatedTypeVariables; @@ -90,7 +123,7 @@ SolverTrail::Change::relatedTypeVariables(TypeVariableType *typeVar, } SolverTrail::Change -SolverTrail::Change::inferredBindings(TypeVariableType *typeVar, +SolverTrail::Change::InferredBindings(TypeVariableType *typeVar, Constraint *constraint) { Change result; result.Kind = ChangeKind::InferredBindings; @@ -100,7 +133,7 @@ SolverTrail::Change::inferredBindings(TypeVariableType *typeVar, } SolverTrail::Change -SolverTrail::Change::retractedBindings(TypeVariableType *typeVar, +SolverTrail::Change::RetractedBindings(TypeVariableType *typeVar, Constraint *constraint) { Change result; result.Kind = ChangeKind::RetractedBindings; @@ -110,7 +143,7 @@ SolverTrail::Change::retractedBindings(TypeVariableType *typeVar, } SolverTrail::Change -SolverTrail::Change::updatedTypeVariable( +SolverTrail::Change::UpdatedTypeVariable( TypeVariableType *typeVar, llvm::PointerUnion parentOrFixed, unsigned options) { @@ -123,7 +156,7 @@ SolverTrail::Change::updatedTypeVariable( } SolverTrail::Change -SolverTrail::Change::addedConversionRestriction(Type srcType, Type dstType) { +SolverTrail::Change::AddedConversionRestriction(Type srcType, Type dstType) { Change result; result.Kind = ChangeKind::AddedConversionRestriction; result.Restriction.SrcType = srcType; @@ -132,15 +165,15 @@ SolverTrail::Change::addedConversionRestriction(Type srcType, Type dstType) { } SolverTrail::Change -SolverTrail::Change::addedFix(ConstraintFix *fix) { +SolverTrail::Change::AddedFix(ConstraintFix *fix) { Change result; result.Kind = ChangeKind::AddedFix; - result.Fix = fix; + result.TheFix = fix; return result; } SolverTrail::Change -SolverTrail::Change::addedFixedRequirement(GenericTypeParamType *GP, +SolverTrail::Change::AddedFixedRequirement(GenericTypeParamType *GP, unsigned reqKind, Type reqTy) { Change result; @@ -152,83 +185,187 @@ SolverTrail::Change::addedFixedRequirement(GenericTypeParamType *GP, } SolverTrail::Change -SolverTrail::Change::recordedDisjunctionChoice(ConstraintLocator *locator, - unsigned index) { +SolverTrail::Change::RecordedOpenedPackExpansionType(PackExpansionType *expansionTy) { Change result; - result.Kind = ChangeKind::RecordedDisjunctionChoice; - result.Locator = locator; - result.Options = index; + result.Kind = ChangeKind::RecordedOpenedPackExpansionType; + result.TheExpansion = expansionTy; return result; } SolverTrail::Change -SolverTrail::Change::recordedAppliedDisjunction(ConstraintLocator *locator) { +SolverTrail::Change::RecordedPackEnvironment(PackElementExpr *packElement) { Change result; - result.Kind = ChangeKind::RecordedAppliedDisjunction; - result.Locator = locator; + result.Kind = ChangeKind::RecordedPackEnvironment; + result.TheElement = packElement; return result; } SolverTrail::Change -SolverTrail::Change::recordedMatchCallArgumentResult(ConstraintLocator *locator) { +SolverTrail::Change::RecordedNodeType(ASTNode node, Type oldType) { Change result; - result.Kind = ChangeKind::RecordedMatchCallArgumentResult; - result.Locator = locator; + result.Kind = ChangeKind::RecordedNodeType; + result.Node.Node = node; + result.Node.OldType = oldType; return result; } SolverTrail::Change -SolverTrail::Change::recordedOpenedTypes(ConstraintLocator *locator) { +SolverTrail::Change::RecordedKeyPathComponentType(const KeyPathExpr *expr, + unsigned component, + Type oldType) { Change result; - result.Kind = ChangeKind::RecordedOpenedTypes; - result.Locator = locator; + result.Kind = ChangeKind::RecordedKeyPathComponentType; + result.Options = component; + result.KeyPath.Expr = expr; + result.KeyPath.OldType = oldType; return result; } SolverTrail::Change -SolverTrail::Change::recordedOpenedExistentialType(ConstraintLocator *locator) { +SolverTrail::Change::DisabledConstraint(Constraint *constraint) { Change result; - result.Kind = ChangeKind::RecordedOpenedExistentialType; - result.Locator = locator; + result.Kind = ChangeKind::DisabledConstraint; + result.TheConstraint.Constraint = constraint; return result; } SolverTrail::Change -SolverTrail::Change::recordedOpenedPackExpansionType(PackExpansionType *expansionTy) { +SolverTrail::Change::FavoredConstraint(Constraint *constraint) { Change result; - result.Kind = ChangeKind::RecordedOpenedPackExpansionType; - result.ExpansionTy = expansionTy; + result.Kind = ChangeKind::FavoredConstraint; + result.TheConstraint.Constraint = constraint; return result; } SolverTrail::Change -SolverTrail::Change::recordedPackExpansionEnvironment(ConstraintLocator *locator) { +SolverTrail::Change::RecordedResultBuilderTransform(AnyFunctionRef fn) { Change result; - result.Kind = ChangeKind::RecordedPackExpansionEnvironment; - result.Locator = locator; + result.Kind = ChangeKind::RecordedResultBuilderTransform; + result.TheRef = fn; return result; } SolverTrail::Change -SolverTrail::Change::recordedPackEnvironment(PackElementExpr *packElement) { +SolverTrail::Change::RecordedContextualInfo(ASTNode node) { Change result; - result.Kind = ChangeKind::RecordedPackEnvironment; - result.ElementExpr = packElement; + result.Kind = ChangeKind::RecordedContextualInfo; + result.Node.Node = node; + return result; +} + +SolverTrail::Change +SolverTrail::Change::RecordedTarget(SyntacticElementTargetKey key) { + Change result; + result.Kind = ChangeKind::RecordedTarget; + result.Options = unsigned(key.kind); + + switch (key.kind) { + case SyntacticElementTargetKey::Kind::empty: + case SyntacticElementTargetKey::Kind::tombstone: + llvm_unreachable("Invalid SyntacticElementTargetKey::Kind"); + case SyntacticElementTargetKey::Kind::stmtCondElement: + result.TheCondElt = key.storage.stmtCondElement; + break; + case SyntacticElementTargetKey::Kind::expr: + result.TheExpr = key.storage.expr; + break; + case SyntacticElementTargetKey::Kind::closure: + result.TheClosure = cast(key.storage.expr); + break; + case SyntacticElementTargetKey::Kind::stmt: + result.TheStmt = key.storage.stmt; + break; + case SyntacticElementTargetKey::Kind::pattern: + result.ThePattern = key.storage.pattern; + break; + case SyntacticElementTargetKey::Kind::patternBindingEntry: + result.ThePatternBinding = key.storage.patternBindingEntry.patternBinding; + result.Options |= key.storage.patternBindingEntry.index << 8; + break; + case SyntacticElementTargetKey::Kind::varDecl: + result.TheVar = key.storage.varDecl; + break; + case SyntacticElementTargetKey::Kind::functionRef: + result.TheDeclContext = key.storage.functionRef; + break; + } + + return result; +} + +SolverTrail::Change +SolverTrail::Change::RecordedCaseLabelItemInfo(CaseLabelItem *item) { + Change result; + result.Kind = ChangeKind::RecordedCaseLabelItemInfo; + result.TheItem = item; + return result; +} + +SolverTrail::Change +SolverTrail::Change::RecordedPotentialThrowSite(CatchNode catchNode) { + Change result; + result.Kind = ChangeKind::RecordedPotentialThrowSite; + result.TheCatchNode = catchNode; + return result; +} + +SolverTrail::Change +SolverTrail::Change::RecordedIsolatedParam(ParamDecl *param) { + Change result; + result.Kind = ChangeKind::RecordedIsolatedParam; + result.TheParam = param; return result; } SolverTrail::Change -SolverTrail::Change::recordedDefaultedConstraint(ConstraintLocator *locator) { +SolverTrail::Change::RecordedKeyPath(KeyPathExpr *expr) { Change result; - result.Kind = ChangeKind::RecordedDefaultedConstraint; - result.Locator = locator; + result.Kind = ChangeKind::RecordedKeyPath; + result.KeyPath.Expr = expr; return result; } +SyntacticElementTargetKey +SolverTrail::Change::getSyntacticElementTargetKey() const { + ASSERT(Kind == ChangeKind::RecordedTarget); + + auto kind = SyntacticElementTargetKey::Kind(Options & 0xff); + + switch (kind) { + case SyntacticElementTargetKey::Kind::empty: + case SyntacticElementTargetKey::Kind::tombstone: + llvm_unreachable("Invalid SyntacticElementTargetKey::Kind"); + case SyntacticElementTargetKey::Kind::stmtCondElement: + return SyntacticElementTargetKey(TheCondElt); + case SyntacticElementTargetKey::Kind::expr: + return SyntacticElementTargetKey(TheExpr); + case SyntacticElementTargetKey::Kind::closure: + return SyntacticElementTargetKey(TheClosure); + case SyntacticElementTargetKey::Kind::stmt: + return SyntacticElementTargetKey(TheStmt); + case SyntacticElementTargetKey::Kind::pattern: + return SyntacticElementTargetKey(ThePattern); + case SyntacticElementTargetKey::Kind::patternBindingEntry: + return SyntacticElementTargetKey(ThePatternBinding, Options >> 8); + case SyntacticElementTargetKey::Kind::varDecl: + return SyntacticElementTargetKey(TheVar); + case SyntacticElementTargetKey::Kind::functionRef: + return SyntacticElementTargetKey(TheDeclContext); + } +} + void SolverTrail::Change::undo(ConstraintSystem &cs) const { auto &cg = cs.getConstraintGraph(); switch (Kind) { +#define LOCATOR_CHANGE(Name, Map) \ + case ChangeKind::Name: { \ + bool erased = cs.Map.erase(TheLocator); \ + ASSERT(erased); \ + break; \ + } +#include "swift/Sema/CSTrail.def" + case ChangeKind::AddedTypeVariable: cg.removeNode(TypeVar); break; @@ -270,7 +407,7 @@ void SolverTrail::Change::undo(ConstraintSystem &cs) const { break; case ChangeKind::AddedFix: - cs.removeFix(Fix); + cs.removeFix(TheFix); break; case ChangeKind::AddedFixedRequirement: @@ -278,40 +415,77 @@ void SolverTrail::Change::undo(ConstraintSystem &cs) const { FixedRequirement.ReqTy); break; - case ChangeKind::RecordedDisjunctionChoice: - cs.removeDisjunctionChoice(Locator); + case ChangeKind::RecordedOpenedPackExpansionType: + cs.removeOpenedPackExpansionType(TheExpansion); break; - case ChangeKind::RecordedAppliedDisjunction: - cs.removeAppliedDisjunction(Locator); + case ChangeKind::RecordedPackEnvironment: + cs.removePackEnvironment(TheElement); break; - case ChangeKind::RecordedMatchCallArgumentResult: - cs.removeMatchCallArgumentResult(Locator); + case ChangeKind::RecordedNodeType: + cs.restoreType(Node.Node, Node.OldType); break; - case ChangeKind::RecordedOpenedTypes: - cs.removeOpenedType(Locator); + case ChangeKind::RecordedKeyPathComponentType: + cs.restoreType(KeyPath.Expr, Options, KeyPath.OldType); break; - case ChangeKind::RecordedOpenedExistentialType: - cs.removeOpenedExistentialType(Locator); + case ChangeKind::DisabledConstraint: + TheConstraint.Constraint->setEnabled(); break; - case ChangeKind::RecordedOpenedPackExpansionType: - cs.removeOpenedPackExpansionType(ExpansionTy); + case ChangeKind::FavoredConstraint: + ASSERT(TheConstraint.Constraint->isFavored()); + TheConstraint.Constraint->setFavored(false); break; - case ChangeKind::RecordedPackExpansionEnvironment: - cs.removePackExpansionEnvironment(Locator); + case ChangeKind::RecordedResultBuilderTransform: + cs.removeResultBuilderTransform(TheRef); break; - case ChangeKind::RecordedPackEnvironment: - cs.removePackEnvironment(ElementExpr); + case ChangeKind::AppliedPropertyWrapper: + cs.removePropertyWrapper(TheExpr); + break; + + case ChangeKind::RecordedClosureType: + cs.removeClosureType(TheClosure); + break; + + case ChangeKind::RecordedImpliedResult: + cs.removeImpliedResult(TheExpr); + break; + + case ChangeKind::RecordedContextualInfo: + cs.removeContextualInfo(Node.Node); + break; + + case ChangeKind::RecordedTarget: + cs.removeTargetFor(getSyntacticElementTargetKey()); + break; + + case ChangeKind::RecordedCaseLabelItemInfo: + cs.removeCaseLabelItemInfo(TheItem); + break; + + case ChangeKind::RecordedPotentialThrowSite: + cs.removePotentialThrowSite(TheCatchNode); + break; + + case ChangeKind::RecordedExprPattern: + cs.removeExprPatternFor(TheExpr); + break; + + case ChangeKind::RecordedIsolatedParam: + cs.removeIsolatedParam(TheParam); break; - case ChangeKind::RecordedDefaultedConstraint: - cs.removeDefaultedConstraint(Locator); + case ChangeKind::RecordedPreconcurrencyClosure: + cs.removePreconcurrencyClosure(TheClosure); + break; + + case ChangeKind::RecordedKeyPath: + cs.removeKeyPath(KeyPath.Expr); break; } } @@ -325,14 +499,35 @@ void SolverTrail::Change::dump(llvm::raw_ostream &out, out.indent(indent); switch (Kind) { + +#define LOCATOR_CHANGE(Name, _) \ + case ChangeKind::Name: \ + out << "(" << #Name << " at "; \ + TheLocator->dump(&cs.getASTContext().SourceMgr, out); \ + out << ")\n"; \ + break; +#define EXPR_CHANGE(Name) \ + case ChangeKind::Name: \ + out << "(" << #Name << " "; \ + simple_display(out, TheExpr); \ + out << ")\n"; \ + break; +#define CLOSURE_CHANGE(Name) \ + case ChangeKind::Name: \ + out << "(" << #Name << " "; \ + simple_display(out, TheClosure); \ + out << ")\n"; \ + break; +#include "swift/Sema/CSTrail.def" + case ChangeKind::AddedTypeVariable: - out << "(added type variable "; + out << "(AddedTypeVariable "; TypeVar->print(out, PO); out << ")\n"; break; case ChangeKind::AddedConstraint: - out << "(added constraint "; + out << "(AddedConstraint "; TheConstraint.Constraint->print(out, &cs.getASTContext().SourceMgr, indent + 2); out << " to type variable "; @@ -341,7 +536,7 @@ void SolverTrail::Change::dump(llvm::raw_ostream &out, break; case ChangeKind::RemovedConstraint: - out << "(removed constraint "; + out << "(RemovedConstraint "; TheConstraint.Constraint->print(out, &cs.getASTContext().SourceMgr, indent + 2); out << " from type variable "; @@ -350,14 +545,14 @@ void SolverTrail::Change::dump(llvm::raw_ostream &out, break; case ChangeKind::ExtendedEquivalenceClass: { - out << "(equivalence "; + out << "(ExtendedEquivalenceClass "; EquivClass.TypeVar->print(out, PO); out << " " << EquivClass.PrevSize << ")\n"; break; } case ChangeKind::RelatedTypeVariables: - out << "(related type variable "; + out << "(RelatedTypeVariables "; Relation.TypeVar->print(out, PO); out << " with "; Relation.OtherTypeVar->print(out, PO); @@ -365,7 +560,7 @@ void SolverTrail::Change::dump(llvm::raw_ostream &out, break; case ChangeKind::InferredBindings: - out << "(inferred bindings from "; + out << "(InferredBindings from "; TheConstraint.Constraint->print(out, &cs.getASTContext().SourceMgr, indent + 2); out << " for type variable "; @@ -374,7 +569,7 @@ void SolverTrail::Change::dump(llvm::raw_ostream &out, break; case ChangeKind::RetractedBindings: - out << "(retracted bindings from "; + out << "(RetractedBindings from "; TheConstraint.Constraint->print(out, &cs.getASTContext().SourceMgr, indent + 2); out << " for type variable "; @@ -383,7 +578,7 @@ void SolverTrail::Change::dump(llvm::raw_ostream &out, break; case ChangeKind::UpdatedTypeVariable: { - out << "(updated type variable "; + out << "(UpdatedTypeVariable "; Update.TypeVar->print(out, PO); auto parentOrFixed = Update.TypeVar->getImpl().ParentOrFixed; @@ -402,7 +597,7 @@ void SolverTrail::Change::dump(llvm::raw_ostream &out, } case ChangeKind::AddedConversionRestriction: - out << "(added restriction with source "; + out << "(AddedConversionRestriction with source "; Restriction.SrcType->print(out, PO); out << " and destination "; Restriction.DstType->print(out, PO); @@ -410,13 +605,13 @@ void SolverTrail::Change::dump(llvm::raw_ostream &out, break; case ChangeKind::AddedFix: - out << "(added a fix "; - Fix->print(out); + out << "(AddedFix "; + TheFix->print(out); out << ")\n"; break; case ChangeKind::AddedFixedRequirement: - out << "(added a fixed requirement "; + out << "(AddedFixedRequirement "; FixedRequirement.GP->print(out, PO); out << " kind "; out << Options << " "; @@ -424,56 +619,91 @@ void SolverTrail::Change::dump(llvm::raw_ostream &out, out << ")\n"; break; - case ChangeKind::RecordedDisjunctionChoice: - out << "(recorded disjunction choice at "; - Locator->dump(&cs.getASTContext().SourceMgr, out); - out << " index "; - out << Options << ")\n"; + case ChangeKind::RecordedOpenedPackExpansionType: + out << "(RecordedOpenedPackExpansionType for "; + TheExpansion->print(out, PO); + out << ")\n"; break; - case ChangeKind::RecordedAppliedDisjunction: - out << "(recorded applied disjunction at "; - Locator->dump(&cs.getASTContext().SourceMgr, out); + case ChangeKind::RecordedPackEnvironment: + // FIXME: Print short form of PackExpansionExpr + out << "(RecordedPackEnvironment "; + simple_display(out, TheElement); + out << "\n"; + break; + + case ChangeKind::RecordedNodeType: + out << "(RecordedNodeType at "; + Node.Node.getStartLoc().print(out, cs.getASTContext().SourceMgr); + out << " previous "; + if (Node.OldType) + Node.OldType->print(out, PO); + else + out << "null"; out << ")\n"; break; - case ChangeKind::RecordedMatchCallArgumentResult: - out << "(recorded argument matching choice at "; - Locator->dump(&cs.getASTContext().SourceMgr, out); - out << ")\n"; + case ChangeKind::RecordedKeyPathComponentType: + out << "(RecordedKeyPathComponentType "; + simple_display(out, KeyPath.Expr); + out << " with component type "; + if (Node.OldType) + Node.OldType->print(out, PO); + else + out << "null"; + out << " for component " << Options << ")\n"; break; - case ChangeKind::RecordedOpenedTypes: - out << "(recorded list of opened types at "; - Locator->dump(&cs.getASTContext().SourceMgr, out); + case ChangeKind::DisabledConstraint: + out << "(DisabledConstraint "; + TheConstraint.Constraint->print(out, &cs.getASTContext().SourceMgr, + indent + 2); out << ")\n"; break; - case ChangeKind::RecordedOpenedExistentialType: - out << "(recorded opened existential type at "; - Locator->dump(&cs.getASTContext().SourceMgr, out); + case ChangeKind::FavoredConstraint: + out << "(FavoredConstraint "; + TheConstraint.Constraint->print(out, &cs.getASTContext().SourceMgr, + indent + 2); out << ")\n"; break; - case ChangeKind::RecordedOpenedPackExpansionType: - out << "(recorded opened pack expansion type for "; - ExpansionTy->print(out, PO); + case ChangeKind::RecordedResultBuilderTransform: + out << "(RecordedResultBuilderTransform "; + simple_display(out, TheRef); out << ")\n"; break; - case ChangeKind::RecordedPackExpansionEnvironment: - out << "(recorded pack expansion environment at "; - Locator->dump(&cs.getASTContext().SourceMgr, out); + case ChangeKind::RecordedContextualInfo: + // FIXME: Print short form of ASTNode + out << "(RecordedContextualInfo)\n"; + break; + + case ChangeKind::RecordedTarget: + out << "(RecordedTarget "; + getSyntacticElementTargetKey().dump(out); out << ")\n"; break; - case ChangeKind::RecordedPackEnvironment: - out << "(recorded pack environment)\n"; + case ChangeKind::RecordedCaseLabelItemInfo: + // FIXME: Print something here + out << "(RecordedCaseLabelItemInfo)\n"; + break; + + case ChangeKind::RecordedPotentialThrowSite: + // FIXME: Print something here + out << "(RecordedPotentialThrowSite)\n"; break; - case ChangeKind::RecordedDefaultedConstraint: - out << "(recorded defaulted constraint at "; - Locator->dump(&cs.getASTContext().SourceMgr, out); + case ChangeKind::RecordedIsolatedParam: + out << "(RecordedIsolatedParam "; + TheParam->dumpRef(out); + out << ")\n"; + break; + + case ChangeKind::RecordedKeyPath: + out << "(RecordedKeyPath "; + simple_display(out, KeyPath.Expr); out << ")\n"; break; } @@ -485,6 +715,7 @@ void SolverTrail::recordChange(Change change) { Changes.push_back(change); + ++Profile[unsigned(change.Kind)]; ++Total; if (Changes.size() > Max) Max = Changes.size(); @@ -496,9 +727,19 @@ void SolverTrail::undo(unsigned toIndex) { if (CS.inInvalidState()) return; + auto dumpHistogram = [&]() { +#define CHANGE(Name) \ + if (auto count = Profile[unsigned(ChangeKind::Name)]) \ + llvm::dbgs() << "* " << #Name << ": " << count << "\n"; +#include "swift/Sema/CSTrail.def" + }; + LLVM_DEBUG(llvm::dbgs() << "decisions " << Changes.size() << " max " << Max - << " total " << Total << "\n"); + << " total " << Total << "\n"; + dumpHistogram(); + llvm::dbgs() << "\n"); + ASSERT(Changes.size() >= toIndex && "Trail corrupted"); ASSERT(!UndoActive); UndoActive = true; diff --git a/lib/Sema/ConstraintGraph.cpp b/lib/Sema/ConstraintGraph.cpp index f6d3f14d28ad3..096a99541515e 100644 --- a/lib/Sema/ConstraintGraph.cpp +++ b/lib/Sema/ConstraintGraph.cpp @@ -84,7 +84,7 @@ ConstraintGraph::lookupNode(TypeVariableType *typeVar) { // recordChange() to assert if there's an active undo. It is not valid to // create new nodes during an undo. if (CS.solverState) - CS.recordChange(SolverTrail::Change::addedTypeVariable(typeVar)); + CS.recordChange(SolverTrail::Change::AddedTypeVariable(typeVar)); // If this type variable is not the representative of its equivalence class, // add it to its representative's set of equivalences. @@ -400,7 +400,7 @@ void ConstraintGraph::addConstraint(Constraint *constraint) { for (auto typeVar : referencedTypeVars) { // Record the change, if there are active scopes. if (CS.isRecordingChanges()) - CS.recordChange(SolverTrail::Change::addedConstraint(typeVar, constraint)); + CS.recordChange(SolverTrail::Change::AddedConstraint(typeVar, constraint)); addConstraint(typeVar, constraint); @@ -420,7 +420,7 @@ void ConstraintGraph::addConstraint(Constraint *constraint) { if (referencedTypeVars.empty()) { // Record the change, if there are active scopes. if (CS.isRecordingChanges()) - CS.recordChange(SolverTrail::Change::addedConstraint(nullptr, constraint)); + CS.recordChange(SolverTrail::Change::AddedConstraint(nullptr, constraint)); addConstraint(nullptr, constraint); } @@ -455,7 +455,7 @@ void ConstraintGraph::removeConstraint(Constraint *constraint) { // Record the change, if there are active scopes. if (CS.isRecordingChanges()) - CS.recordChange(SolverTrail::Change::removedConstraint(typeVar, constraint)); + CS.recordChange(SolverTrail::Change::RemovedConstraint(typeVar, constraint)); removeConstraint(typeVar, constraint); } @@ -464,7 +464,7 @@ void ConstraintGraph::removeConstraint(Constraint *constraint) { if (referencedTypeVars.empty()) { // Record the change, if there are active scopes. if (CS.isRecordingChanges()) - CS.recordChange(SolverTrail::Change::removedConstraint(nullptr, constraint)); + CS.recordChange(SolverTrail::Change::RemovedConstraint(nullptr, constraint)); removeConstraint(nullptr, constraint); } @@ -503,7 +503,7 @@ void ConstraintGraph::mergeNodes(TypeVariableType *typeVar1, // Record the change, if there are active scopes. if (CS.isRecordingChanges()) { CS.recordChange( - SolverTrail::Change::extendedEquivalenceClass( + SolverTrail::Change::ExtendedEquivalenceClass( typeVarRep, repNode.getEquivalenceClass().size())); } @@ -533,7 +533,7 @@ void ConstraintGraph::bindTypeVariable(TypeVariableType *typeVar, Type fixed) { // Record the change, if there are active scopes. if (CS.isRecordingChanges()) - CS.recordChange(SolverTrail::Change::relatedTypeVariables(typeVar, otherTypeVar)); + CS.recordChange(SolverTrail::Change::RelatedTypeVariables(typeVar, otherTypeVar)); } } diff --git a/lib/Sema/ConstraintSystem.cpp b/lib/Sema/ConstraintSystem.cpp index 3d14ac3a5e5c4..9cf9c4269aa0e 100644 --- a/lib/Sema/ConstraintSystem.cpp +++ b/lib/Sema/ConstraintSystem.cpp @@ -266,8 +266,8 @@ void ConstraintSystem::addConversionRestriction( if (!inserted) return; - if (isRecordingChanges()) { - recordChange(SolverTrail::Change::addedConversionRestriction( + if (solverState) { + recordChange(SolverTrail::Change::AddedConversionRestriction( srcType, dstType)); } } @@ -281,11 +281,10 @@ void ConstraintSystem::removeConversionRestriction( void ConstraintSystem::addFix(ConstraintFix *fix) { bool inserted = Fixes.insert(fix); - if (!inserted) - return; + ASSERT(inserted); - if (isRecordingChanges()) - recordChange(SolverTrail::Change::addedFix(fix)); + if (solverState) + recordChange(SolverTrail::Change::AddedFix(fix)); } void ConstraintSystem::removeFix(ConstraintFix *fix) { @@ -294,32 +293,23 @@ void ConstraintSystem::removeFix(ConstraintFix *fix) { } void ConstraintSystem::recordDisjunctionChoice( - ConstraintLocator *locator, - unsigned index) { - // We shouldn't ever register disjunction choices multiple times. - auto inserted = DisjunctionChoices.insert( - std::make_pair(locator, index)); - if (!inserted.second) { - ASSERT(inserted.first->second == index); - return; - } + ConstraintLocator *locator, unsigned index) { + bool inserted = DisjunctionChoices.insert({locator, index}).second; + ASSERT(inserted); - if (isRecordingChanges()) { - recordChange(SolverTrail::Change::recordedDisjunctionChoice( - locator, index)); - } + if (solverState) + recordChange(SolverTrail::Change::RecordedDisjunctionChoice(locator)); } void ConstraintSystem::recordAppliedDisjunction( ConstraintLocator *locator, FunctionType *fnType) { // We shouldn't ever register disjunction choices multiple times. - auto inserted = AppliedDisjunctions.insert( - std::make_pair(locator, fnType)); - if (inserted.second) { - if (isRecordingChanges()) { - recordChange(SolverTrail::Change::recordedAppliedDisjunction(locator)); - } - } + bool inserted = AppliedDisjunctions.insert( + std::make_pair(locator, fnType)).second; + ASSERT(inserted); + + if (solverState) + recordChange(SolverTrail::Change::RecordedAppliedDisjunction(locator)); } /// Retrieve a dynamic result signature for the given declaration. @@ -445,6 +435,18 @@ bool ConstraintSystem::containsIDEInspectionTarget( Context.SourceMgr); } +void ConstraintSystem::recordPotentialThrowSite( + CatchNode catchNode, PotentialThrowSite site) { + potentialThrowSites.push_back({catchNode, site}); + if (solverState) + recordChange(SolverTrail::Change::RecordedPotentialThrowSite(catchNode)); +} + +void ConstraintSystem::removePotentialThrowSite(CatchNode catchNode) { + ASSERT(potentialThrowSites.back().first == catchNode); + potentialThrowSites.pop_back(); +} + void ConstraintSystem::recordPotentialThrowSite( PotentialThrowSite::Kind kind, Type type, ConstraintLocatorBuilder locator) { @@ -471,9 +473,8 @@ void ConstraintSystem::recordPotentialThrowSite( // do..catch statements without an explicit `throws` clause do infer // thrown types. if (auto doCatch = catchNode.dyn_cast()) { - potentialThrowSites.push_back( - {catchNode, - PotentialThrowSite{kind, type, getConstraintLocator(locator)}}); + PotentialThrowSite site{kind, type, getConstraintLocator(locator)}; + recordPotentialThrowSite(catchNode, site); return; } @@ -486,9 +487,8 @@ void ConstraintSystem::recordPotentialThrowSite( if (!closureEffects(closure).isThrowing()) return; - potentialThrowSites.push_back( - {catchNode, - PotentialThrowSite{kind, type, getConstraintLocator(locator)}}); + PotentialThrowSite site{kind, type, getConstraintLocator(locator)}; + recordPotentialThrowSite(catchNode, site); } Type ConstraintSystem::getCaughtErrorType(CatchNode catchNode) { @@ -853,10 +853,10 @@ std::pair ConstraintSystem::openExistentialType( void ConstraintSystem::recordOpenedExistentialType( ConstraintLocator *locator, OpenedArchetypeType *opened) { bool inserted = OpenedExistentialTypes.insert({locator, opened}).second; - if (inserted) { - if (isRecordingChanges()) - recordChange(SolverTrail::Change::recordedOpenedExistentialType(locator)); - } + ASSERT(inserted); + + if (solverState) + recordChange(SolverTrail::Change::RecordedOpenedExistentialType(locator)); } GenericEnvironment * @@ -894,12 +894,10 @@ ConstraintSystem::getPackElementEnvironment(ConstraintLocator *locator, void ConstraintSystem::recordPackExpansionEnvironment( ConstraintLocator *locator, std::pair uuidAndShape) { bool inserted = PackExpansionEnvironments.insert({locator, uuidAndShape}).second; - if (inserted) { - if (isRecordingChanges()) { - recordChange( - SolverTrail::Change::recordedPackExpansionEnvironment(locator)); - } - } + ASSERT(inserted); + + if (solverState) + recordChange(SolverTrail::Change::RecordedPackExpansionEnvironment(locator)); } PackExpansionExpr * @@ -910,12 +908,11 @@ ConstraintSystem::getPackEnvironment(PackElementExpr *packElement) const { void ConstraintSystem::addPackEnvironment(PackElementExpr *packElement, PackExpansionExpr *packExpansion) { - bool inserted = - PackEnvironments.insert({packElement, packExpansion}).second; - if (inserted) { - if (isRecordingChanges()) - recordChange(SolverTrail::Change::recordedPackEnvironment(packElement)); - } + bool inserted = PackEnvironments.insert({packElement, packExpansion}).second; + ASSERT(inserted); + + if (solverState) + recordChange(SolverTrail::Change::RecordedPackEnvironment(packElement)); } /// Extend the given depth map by adding depths for all of the subexpressions @@ -1028,7 +1025,8 @@ Type ConstraintSystem::openUnboundGenericType(GenericTypeDecl *decl, openGeneric(decl->getDeclContext(), decl->getGenericSignature(), locator, replacements); - recordOpenedTypes(locator, replacements); + // FIXME: Get rid of fixmeAllowDuplicates. + recordOpenedTypes(locator, replacements, /*fixmeAllowDuplicates=*/true); if (parentTy) { const auto parentTyInContext = @@ -1278,10 +1276,10 @@ Type ConstraintSystem::openPackExpansionType(PackExpansionType *expansion, void ConstraintSystem::recordOpenedPackExpansionType(PackExpansionType *expansion, TypeVariableType *expansionVar) { bool inserted = OpenedPackExpansionTypes.insert({expansion, expansionVar}).second; - if (inserted) { - if (isRecordingChanges()) - recordChange(SolverTrail::Change::recordedOpenedPackExpansionType(expansion)); - } + ASSERT(inserted); + + if (solverState) + recordChange(SolverTrail::Change::RecordedOpenedPackExpansionType(expansion)); } Type ConstraintSystem::openOpaqueType(OpaqueTypeArchetypeType *opaque, @@ -1687,20 +1685,16 @@ Type ConstraintSystem::getUnopenedTypeOfReference( void ConstraintSystem::recordOpenedType( ConstraintLocator *locator, ArrayRef openedTypes) { bool inserted = OpenedTypes.insert({locator, openedTypes}).second; - if (inserted) { - if (isRecordingChanges()) - recordChange(SolverTrail::Change::recordedOpenedTypes(locator)); - } -} + ASSERT(inserted); -void ConstraintSystem::removeOpenedType(ConstraintLocator *locator) { - bool erased = OpenedTypes.erase(locator); - ASSERT(erased); + if (solverState) + recordChange(SolverTrail::Change::RecordedOpenedTypes(locator)); } void ConstraintSystem::recordOpenedTypes( ConstraintLocatorBuilder locator, - const OpenedTypeMap &replacements) { + const OpenedTypeMap &replacements, + bool fixmeAllowDuplicates) { if (replacements.empty()) return; @@ -1721,7 +1715,10 @@ void ConstraintSystem::recordOpenedTypes( OpenedType* openedTypes = Allocator.Allocate(replacements.size()); std::copy(replacements.begin(), replacements.end(), openedTypes); - recordOpenedType( + + // FIXME: Get rid of fixmeAllowDuplicates. + if (!fixmeAllowDuplicates || OpenedTypes.count(locatorPtr) == 0) + recordOpenedType( locatorPtr, llvm::ArrayRef(openedTypes, replacements.size())); } @@ -3545,12 +3542,13 @@ void ConstraintSystem::bindOverloadType( // Associate an argument list for the implicit x[dynamicMember:] subscript // if we haven't already. - auto *&argList = ArgumentLists[getArgumentInfoLocator(callLoc)]; - if (!argList) { - argList = ArgumentList::createImplicit( + auto *argLoc = getArgumentInfoLocator(callLoc); + if (ArgumentLists.find(argLoc) == ArgumentLists.end()) { + auto *argList = ArgumentList::createImplicit( ctx, {Argument(SourceLoc(), ctx.Id_dynamicMember, /*expr*/ nullptr)}, /*firstTrailingClosureIndex=*/std::nullopt, AllocationArena::ConstraintSolver); + recordArgumentList(argLoc, argList); } auto *callerTy = FunctionType::get( @@ -3743,6 +3741,15 @@ void ConstraintSystem::bindOverloadType( llvm_unreachable("Unhandled OverloadChoiceKind in switch."); } +void ConstraintSystem::recordResolvedOverload(ConstraintLocator *locator, + SelectedOverload overload) { + bool inserted = ResolvedOverloads.insert({locator, overload}).second; + ASSERT(inserted); + + if (solverState) + recordChange(SolverTrail::Change::ResolvedOverload(locator)); +} + void ConstraintSystem::resolveOverload(ConstraintLocator *locator, Type boundType, OverloadChoice choice, @@ -3998,9 +4005,7 @@ void ConstraintSystem::resolveOverload(ConstraintLocator *locator, auto overload = SelectedOverload{ choice, openedType, adjustedOpenedType, refType, adjustedRefType, boundType}; - auto result = ResolvedOverloads.insert({locator, overload}); - assert(result.second && "Already resolved this overload?"); - (void)result; + recordResolvedOverload(locator, overload); // Add the constraints necessary to bind the overload type. bindOverloadType(overload, boundType, locator, useDC, @@ -4481,12 +4486,12 @@ size_t Solution::getTotalMemory() const { size_in_bytes(targets) + size_in_bytes(caseLabelItems) + size_in_bytes(exprPatterns) + - (isolatedParams.size() * sizeof(void *)) + - (preconcurrencyClosures.size() * sizeof(void *)) + + size_in_bytes(isolatedParams) + + size_in_bytes(preconcurrencyClosures) + size_in_bytes(resultBuilderTransformed) + size_in_bytes(appliedPropertyWrappers) + size_in_bytes(argumentLists) + - ImplicitCallAsFunctionRoots.getMemorySize() + + size_in_bytes(ImplicitCallAsFunctionRoots) + size_in_bytes(SynthesizedConformances); } @@ -6517,13 +6522,20 @@ ArgumentList *ConstraintSystem::getArgumentList(ConstraintLocator *locator) { return nullptr; } +void ConstraintSystem::recordArgumentList(ConstraintLocator *locator, + ArgumentList *args) { + bool inserted = ArgumentLists.insert({locator, args}).second; + ASSERT(inserted); + + if (solverState) + recordChange(SolverTrail::Change::RecordedArgumentList(locator)); +} + void ConstraintSystem::associateArgumentList(ConstraintLocator *locator, ArgumentList *args) { - assert(locator && locator->getAnchor()); - auto *argInfoLoc = getArgumentInfoLocator(locator); - auto inserted = ArgumentLists.insert({argInfoLoc, args}).second; - assert(inserted && "Multiple argument lists at locator?"); - (void)inserted; + ASSERT(locator && locator->getAnchor()); + auto *argLoc = getArgumentInfoLocator(locator); + recordArgumentList(argLoc, args); } ArgumentList *Solution::getArgumentList(ConstraintLocator *locator) const { @@ -7443,8 +7455,8 @@ void ConstraintSystem::recordFixedRequirement(GenericTypeParamType *GP, bool inserted = FixedRequirements.insert( std::make_tuple(GP, reqKind, requirementTy.getPointer())).second; if (inserted) { - if (isRecordingChanges()) { - recordChange(SolverTrail::Change::addedFixedRequirement( + if (solverState) { + recordChange(SolverTrail::Change::AddedFixedRequirement( GP, reqKind, requirementTy)); } }