diff --git a/include/swift/Sema/CSBindings.h b/include/swift/Sema/CSBindings.h index 3e17979652d90..825990660c37b 100644 --- a/include/swift/Sema/CSBindings.h +++ b/include/swift/Sema/CSBindings.h @@ -469,10 +469,6 @@ class BindingSet { /// checking. bool isViable(PotentialBinding &binding, bool isTransitive); - explicit operator bool() const { - return hasViableBindings() || isDirectHole(); - } - /// Determine whether this set has any "viable" (or non-hole) bindings. /// /// A viable binding could be - a direct or transitive binding @@ -486,6 +482,12 @@ class BindingSet { !Defaults.empty(); } + /// Determine whether this set can be chosen as the next binding set + /// to attempt. + bool isViable() const { + return hasViableBindings() || isDirectHole(); + } + ArrayRef getConformanceRequirements() const { return Protocols; } @@ -544,6 +546,8 @@ class BindingSet { /// Check if this binding is favored over a conjunction. bool favoredOverConjunction(Constraint *conjunction) const; + void inferTransitiveKeyPathBindings(); + /// Detect `subtype` relationship between two type variables and /// attempt to infer supertype bindings transitively e.g. /// @@ -553,19 +557,27 @@ class BindingSet { /// /// \param inferredBindings The set of all bindings inferred for type /// variables in the workset. - void inferTransitiveBindings(); + void inferTransitiveSupertypeBindings(); + + void inferTransitiveUnresolvedMemberRefBindings(); /// Detect subtype, conversion or equivalence relationship /// between two type variables and attempt to propagate protocol /// requirements down the subtype or equivalence chain. void inferTransitiveProtocolRequirements(); - /// Finalize binding computation for this type variable by - /// inferring bindings from context e.g. transitive bindings. + /// Check whether the given binding set covers any of the + /// literal protocols associated with this type variable. + void determineLiteralCoverage(); + + /// Finalize binding computation for key path type variables. /// /// \returns true if finalization successful (which makes binding set viable), /// and false otherwise. - bool finalize(bool transitive); + bool finalizeKeyPathBindings(); + + /// Handle diagnostics of unresolved member chains. + void finalizeUnresolvedMemberChainResult(); static BindingScore formBindingScore(const BindingSet &b); @@ -590,10 +602,6 @@ class BindingSet { void addDefault(Constraint *constraint); - /// Check whether the given binding set covers any of the - /// literal protocols associated with this type variable. - void determineLiteralCoverage(); - StringRef getLiteralBindingKind(LiteralBindingKind K) const { #define ENTRY(Kind, String) \ case LiteralBindingKind::Kind: \ diff --git a/lib/Sema/CSBindings.cpp b/lib/Sema/CSBindings.cpp index 04a4704344cbc..785d5f8d87c6a 100644 --- a/lib/Sema/CSBindings.cpp +++ b/lib/Sema/CSBindings.cpp @@ -103,8 +103,7 @@ bool BindingSet::isDirectHole() const { if (!CS.shouldAttemptFixes()) return false; - return Bindings.empty() && getNumViableLiteralBindings() == 0 && - Defaults.empty() && TypeVar->getImpl().canBindToHole(); + return !hasViableBindings() && TypeVar->getImpl().canBindToHole(); } static bool isGenericParameter(TypeVariableType *TypeVar) { @@ -494,9 +493,7 @@ void BindingSet::inferTransitiveProtocolRequirements() { } while (!workList.empty()); } -void BindingSet::inferTransitiveBindings() { - using BindingKind = AllowedBindingKind; - +void BindingSet::inferTransitiveKeyPathBindings() { // If the current type variable represents a key path root type // let's try to transitively infer its type through bindings of // a key path type. @@ -551,7 +548,7 @@ void BindingSet::inferTransitiveBindings() { } } else { addBinding( - binding.withSameSource(inferredRootTy, BindingKind::Exact), + binding.withSameSource(inferredRootTy, AllowedBindingKind::Exact), /*isTransitive=*/true); } } @@ -559,7 +556,9 @@ void BindingSet::inferTransitiveBindings() { } } } +} +void BindingSet::inferTransitiveSupertypeBindings() { for (const auto &entry : Info.SupertypeOf) { auto &node = CS.getConstraintGraph()[entry.first]; if (!node.hasBindingSet()) @@ -609,8 +608,8 @@ void BindingSet::inferTransitiveBindings() { // either be Exact or Supertypes in order for it to make sense // to add Supertype bindings based on the relationship between // our type variables. - if (binding.Kind != BindingKind::Exact && - binding.Kind != BindingKind::Supertypes) + if (binding.Kind != AllowedBindingKind::Exact && + binding.Kind != AllowedBindingKind::Supertypes) continue; auto type = binding.BindingType; @@ -621,12 +620,49 @@ void BindingSet::inferTransitiveBindings() { if (ConstraintSystem::typeVarOccursInType(TypeVar, type)) continue; - addBinding(binding.withSameSource(type, BindingKind::Supertypes), + addBinding(binding.withSameSource(type, AllowedBindingKind::Supertypes), /*isTransitive=*/true); } } } +void BindingSet::inferTransitiveUnresolvedMemberRefBindings() { + if (!hasViableBindings()) { + if (auto *locator = TypeVar->getImpl().getLocator()) { + if (locator->isLastElement()) { + // If this is a base of an unresolved member chain, as a last + // resort effort let's infer base to be a protocol type based + // on contextual conformance requirements. + // + // This allows us to find solutions in cases like this: + // + // \code + // func foo(_: T) {} + // foo(.bar) <- `.bar` should be a static member of `P`. + // \endcode + inferTransitiveProtocolRequirements(); + + if (TransitiveProtocols.has_value()) { + for (auto *constraint : *TransitiveProtocols) { + Type protocolTy = constraint->getSecondType(); + + // Compiler-known marker protocols cannot be extended with members, + // so do not consider them. + if (auto p = protocolTy->getAs()) { + if (ProtocolDecl *decl = p->getDecl()) + if (decl->getKnownProtocolKind() && decl->isMarkerProtocol()) + continue; + } + + addBinding({protocolTy, AllowedBindingKind::Exact, constraint}, + /*isTransitive=*/false); + } + } + } + } + } +} + static Type getKeyPathType(ASTContext &ctx, KeyPathCapability capability, Type rootType, Type valueType) { KeyPathMutability mutability; @@ -664,51 +700,11 @@ static Type getKeyPathType(ASTContext &ctx, KeyPathCapability capability, return keyPathTy; } -bool BindingSet::finalize(bool transitive) { - if (transitive) - inferTransitiveBindings(); - - determineLiteralCoverage(); - +bool BindingSet::finalizeKeyPathBindings() { if (auto *locator = TypeVar->getImpl().getLocator()) { - if (locator->isLastElement()) { - // If this is a base of an unresolved member chain, as a last - // resort effort let's infer base to be a protocol type based - // on contextual conformance requirements. - // - // This allows us to find solutions in cases like this: - // - // \code - // func foo(_: T) {} - // foo(.bar) <- `.bar` should be a static member of `P`. - // \endcode - if (transitive && !hasViableBindings()) { - inferTransitiveProtocolRequirements(); - - if (TransitiveProtocols.has_value()) { - for (auto *constraint : *TransitiveProtocols) { - Type protocolTy = constraint->getSecondType(); - - // Compiler-known marker protocols cannot be extended with members, - // so do not consider them. - if (auto p = protocolTy->getAs()) { - if (ProtocolDecl *decl = p->getDecl()) - if (decl->getKnownProtocolKind() && decl->isMarkerProtocol()) - continue; - } - - addBinding({protocolTy, AllowedBindingKind::Exact, constraint}, - /*isTransitive=*/false); - } - } - } - } - if (TypeVar->getImpl().isKeyPathType()) { auto &ctx = CS.getASTContext(); - - auto *keyPathLoc = TypeVar->getImpl().getLocator(); - auto *keyPath = castToExpr(keyPathLoc->getAnchor()); + auto *keyPath = castToExpr(locator->getAnchor()); bool isValid; std::optional capability; @@ -775,7 +771,7 @@ bool BindingSet::finalize(bool transitive) { auto keyPathTy = getKeyPathType(ctx, *capability, rootTy, CS.getKeyPathValueType(keyPath)); updatedBindings.insert( - {keyPathTy, AllowedBindingKind::Exact, keyPathLoc}); + {keyPathTy, AllowedBindingKind::Exact, locator}); } else if (CS.shouldAttemptFixes()) { auto fixedRootTy = CS.getFixedType(rootTy); // If key path is structurally correct and has a resolved root @@ -802,10 +798,14 @@ bool BindingSet::finalize(bool transitive) { Bindings = std::move(updatedBindings); Defaults.clear(); - - return true; } + } + return true; +} + +void BindingSet::finalizeUnresolvedMemberChainResult() { + if (auto *locator = TypeVar->getImpl().getLocator()) { if (CS.shouldAttemptFixes() && locator->isLastElement()) { // Let's see whether this chain is valid, if it isn't then to avoid @@ -828,8 +828,6 @@ bool BindingSet::finalize(bool transitive) { } } } - - return true; } void BindingSet::addBinding(PotentialBinding binding, bool isTransitive) { @@ -1143,37 +1141,6 @@ std::optional ConstraintSystem::determineBestBindings( node.initBindingSet(); } - // Determine whether given type variable with its set of bindings is - // viable to be attempted on the next step of the solver. If type variable - // has no "direct" bindings of any kind e.g. direct bindings to concrete - // types, default types from "defaultable" constraints or literal - // conformances, such type variable is not viable to be evaluated to be - // attempted next. - auto isViableForRanking = [this](const BindingSet &bindings) -> bool { - auto *typeVar = bindings.getTypeVariable(); - - // Key path root type variable is always viable because it can be - // transitively inferred from key path type during binding set - // finalization. - if (typeVar->getImpl().isKeyPathRoot()) - return true; - - // Type variable representing a base of unresolved member chain should - // always be considered viable for ranking since it's allow to infer - // types from transitive protocol requirements. - if (auto *locator = typeVar->getImpl().getLocator()) { - if (locator->isLastElement()) - return true; - } - - // If type variable is marked as a potential hole there is always going - // to be at least one binding available for it. - if (shouldAttemptFixes() && typeVar->getImpl().canBindToHole()) - return true; - - return bool(bindings); - }; - // Now let's see if we could infer something for related type // variables based on other bindings. for (auto *typeVar : getTypeVariables()) { @@ -1183,6 +1150,16 @@ std::optional ConstraintSystem::determineBestBindings( auto &bindings = node.getBindingSet(); + // Special handling for key paths. + bindings.inferTransitiveKeyPathBindings(); + if (!bindings.finalizeKeyPathBindings()) + continue; + + // Special handling for "leading-dot" unresolved member references, + // like .foo. + bindings.inferTransitiveUnresolvedMemberRefBindings(); + bindings.finalizeUnresolvedMemberChainResult(); + // Before attempting to infer transitive bindings let's check // whether there are any viable "direct" bindings associated with // current type variable, if there are none - it means that this type @@ -1193,12 +1170,12 @@ std::optional ConstraintSystem::determineBestBindings( // associated with given type variable, any default constraints, // or any conformance requirements to literal protocols with can // produce a default type. - bool isViable = isViableForRanking(bindings); + bool isViable = bindings.isViable(); - if (!bindings.finalize(true)) - continue; + bindings.inferTransitiveSupertypeBindings(); + bindings.determineLiteralCoverage(); - if (!bindings || !isViable) + if (!isViable) continue; onCandidate(bindings); @@ -1591,7 +1568,10 @@ BindingSet ConstraintSystem::getBindingsFor(TypeVariableType *typeVar) { assert(!typeVar->getImpl().getFixedType(nullptr) && "has a fixed type"); BindingSet bindings(*this, typeVar, CG[typeVar].getPotentialBindings()); - bindings.finalize(false); + + (void) bindings.finalizeKeyPathBindings(); + bindings.finalizeUnresolvedMemberChainResult(); + bindings.determineLiteralCoverage(); return bindings; } diff --git a/lib/Sema/CSOptimizer.cpp b/lib/Sema/CSOptimizer.cpp index 31992713ceb63..aab3431df14bb 100644 --- a/lib/Sema/CSOptimizer.cpp +++ b/lib/Sema/CSOptimizer.cpp @@ -1105,7 +1105,9 @@ static void determineBestChoicesInContext( // Simply adding it as a binding won't work because if the second argument // is non-optional the overload that returns `T?` would still have a lower // score. - if (!bindingSet && isNilCoalescingOperator(disjunction)) { + if (!bindingSet.hasViableBindings() && + !bindingSet.isDirectHole() && + isNilCoalescingOperator(disjunction)) { auto &cg = cs.getConstraintGraph(); if (llvm::any_of(cg[typeVar].getConstraints(), [&typeVar](Constraint *constraint) { diff --git a/lib/Sema/ConstraintGraph.cpp b/lib/Sema/ConstraintGraph.cpp index 1c47c257ec983..346e2e175909d 100644 --- a/lib/Sema/ConstraintGraph.cpp +++ b/lib/Sema/ConstraintGraph.cpp @@ -921,7 +921,8 @@ bool ConstraintGraph::contractEdges() { // us enough information to decided on l-valueness. if (tyvar1->getImpl().canBindToInOut()) { bool isNotContractable = true; - if (auto bindings = CS.getBindingsFor(tyvar1)) { + auto bindings = CS.getBindingsFor(tyvar1); + if (bindings.isViable()) { // Holes can't be contracted. if (bindings.isHole()) continue; diff --git a/unittests/Sema/BindingInferenceTests.cpp b/unittests/Sema/BindingInferenceTests.cpp index fdd6c6682a489..eef3bb841ac17 100644 --- a/unittests/Sema/BindingInferenceTests.cpp +++ b/unittests/Sema/BindingInferenceTests.cpp @@ -125,7 +125,15 @@ TEST_F(SemaTest, TestIntLiteralBindingInference) { cs.getConstraintGraph()[floatLiteralTy].initBindingSet(); - bindings.finalize(/*transitive=*/true); + bindings.inferTransitiveKeyPathBindings(); + (void) bindings.finalizeKeyPathBindings(); + + bindings.inferTransitiveUnresolvedMemberRefBindings(); + bindings.finalizeUnresolvedMemberChainResult(); + + bindings.inferTransitiveSupertypeBindings(); + + bindings.determineLiteralCoverage(); // Inferred a single transitive binding through `$T_float`. ASSERT_EQ(bindings.Bindings.size(), (unsigned)1); diff --git a/unittests/Sema/SemaFixture.cpp b/unittests/Sema/SemaFixture.cpp index 5ee9c005a5e93..0600e9ad5929e 100644 --- a/unittests/Sema/SemaFixture.cpp +++ b/unittests/Sema/SemaFixture.cpp @@ -140,8 +140,20 @@ BindingSet SemaTest::inferBindings(ConstraintSystem &cs, continue; auto &bindings = node.getBindingSet(); + + // FIXME: This is also called in inferTransitiveUnresolvedMemberRefBindings(), + // why do we need to call it here too? bindings.inferTransitiveProtocolRequirements(); - bindings.finalize(/*transitive=*/true); + + bindings.inferTransitiveKeyPathBindings(); + (void) bindings.finalizeKeyPathBindings(); + + bindings.inferTransitiveUnresolvedMemberRefBindings(); + bindings.finalizeUnresolvedMemberChainResult(); + + bindings.inferTransitiveSupertypeBindings(); + + bindings.determineLiteralCoverage(); } auto &node = cs.getConstraintGraph()[typeVar];