Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 20 additions & 12 deletions include/swift/Sema/CSBindings.h
Original file line number Diff line number Diff line change
Expand Up @@ -469,10 +469,6 @@ class BindingSet {
/// checking.
bool isViable(PotentialBinding &binding, bool isTransitive);

explicit operator bool() const {
return hasViableBindings() || isDirectHole();
}

/// Determine whether this set has any "viable" (or non-hole) bindings.
///
/// A viable binding could be - a direct or transitive binding
Expand All @@ -486,6 +482,12 @@ class BindingSet {
!Defaults.empty();
}

/// Determine whether this set can be chosen as the next binding set
/// to attempt.
bool isViable() const {
return hasViableBindings() || isDirectHole();
}

ArrayRef<Constraint *> getConformanceRequirements() const {
return Protocols;
}
Expand Down Expand Up @@ -544,6 +546,8 @@ class BindingSet {
/// Check if this binding is favored over a conjunction.
bool favoredOverConjunction(Constraint *conjunction) const;

void inferTransitiveKeyPathBindings();

/// Detect `subtype` relationship between two type variables and
/// attempt to infer supertype bindings transitively e.g.
///
Expand All @@ -553,19 +557,27 @@ class BindingSet {
///
/// \param inferredBindings The set of all bindings inferred for type
/// variables in the workset.
void inferTransitiveBindings();
void inferTransitiveSupertypeBindings();

void inferTransitiveUnresolvedMemberRefBindings();

/// Detect subtype, conversion or equivalence relationship
/// between two type variables and attempt to propagate protocol
/// requirements down the subtype or equivalence chain.
void inferTransitiveProtocolRequirements();

/// Finalize binding computation for this type variable by
/// inferring bindings from context e.g. transitive bindings.
/// Check whether the given binding set covers any of the
/// literal protocols associated with this type variable.
void determineLiteralCoverage();

/// Finalize binding computation for key path type variables.
///
/// \returns true if finalization successful (which makes binding set viable),
/// and false otherwise.
bool finalize(bool transitive);
bool finalizeKeyPathBindings();

/// Handle diagnostics of unresolved member chains.
void finalizeUnresolvedMemberChainResult();

static BindingScore formBindingScore(const BindingSet &b);

Expand All @@ -590,10 +602,6 @@ class BindingSet {

void addDefault(Constraint *constraint);

/// Check whether the given binding set covers any of the
/// literal protocols associated with this type variable.
void determineLiteralCoverage();

StringRef getLiteralBindingKind(LiteralBindingKind K) const {
#define ENTRY(Kind, String) \
case LiteralBindingKind::Kind: \
Expand Down
164 changes: 72 additions & 92 deletions lib/Sema/CSBindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,7 @@ bool BindingSet::isDirectHole() const {
if (!CS.shouldAttemptFixes())
return false;

return Bindings.empty() && getNumViableLiteralBindings() == 0 &&
Defaults.empty() && TypeVar->getImpl().canBindToHole();
return !hasViableBindings() && TypeVar->getImpl().canBindToHole();
}

static bool isGenericParameter(TypeVariableType *TypeVar) {
Expand Down Expand Up @@ -494,9 +493,7 @@ void BindingSet::inferTransitiveProtocolRequirements() {
} while (!workList.empty());
}

void BindingSet::inferTransitiveBindings() {
using BindingKind = AllowedBindingKind;

void BindingSet::inferTransitiveKeyPathBindings() {
// If the current type variable represents a key path root type
// let's try to transitively infer its type through bindings of
// a key path type.
Expand Down Expand Up @@ -551,15 +548,17 @@ void BindingSet::inferTransitiveBindings() {
}
} else {
addBinding(
binding.withSameSource(inferredRootTy, BindingKind::Exact),
binding.withSameSource(inferredRootTy, AllowedBindingKind::Exact),
/*isTransitive=*/true);
}
}
}
}
}
}
}

void BindingSet::inferTransitiveSupertypeBindings() {
for (const auto &entry : Info.SupertypeOf) {
auto &node = CS.getConstraintGraph()[entry.first];
if (!node.hasBindingSet())
Expand Down Expand Up @@ -609,8 +608,8 @@ void BindingSet::inferTransitiveBindings() {
// either be Exact or Supertypes in order for it to make sense
// to add Supertype bindings based on the relationship between
// our type variables.
if (binding.Kind != BindingKind::Exact &&
binding.Kind != BindingKind::Supertypes)
if (binding.Kind != AllowedBindingKind::Exact &&
binding.Kind != AllowedBindingKind::Supertypes)
continue;

auto type = binding.BindingType;
Expand All @@ -621,12 +620,49 @@ void BindingSet::inferTransitiveBindings() {
if (ConstraintSystem::typeVarOccursInType(TypeVar, type))
continue;

addBinding(binding.withSameSource(type, BindingKind::Supertypes),
addBinding(binding.withSameSource(type, AllowedBindingKind::Supertypes),
/*isTransitive=*/true);
}
}
}

void BindingSet::inferTransitiveUnresolvedMemberRefBindings() {
if (!hasViableBindings()) {
if (auto *locator = TypeVar->getImpl().getLocator()) {
if (locator->isLastElement<LocatorPathElt::MemberRefBase>()) {
// If this is a base of an unresolved member chain, as a last
// resort effort let's infer base to be a protocol type based
// on contextual conformance requirements.
//
// This allows us to find solutions in cases like this:
//
// \code
// func foo<T: P>(_: T) {}
// foo(.bar) <- `.bar` should be a static member of `P`.
// \endcode
inferTransitiveProtocolRequirements();

if (TransitiveProtocols.has_value()) {
for (auto *constraint : *TransitiveProtocols) {
Type protocolTy = constraint->getSecondType();

// Compiler-known marker protocols cannot be extended with members,
// so do not consider them.
if (auto p = protocolTy->getAs<ProtocolType>()) {
if (ProtocolDecl *decl = p->getDecl())
if (decl->getKnownProtocolKind() && decl->isMarkerProtocol())
continue;
}

addBinding({protocolTy, AllowedBindingKind::Exact, constraint},
/*isTransitive=*/false);
}
}
}
}
}
}

static Type getKeyPathType(ASTContext &ctx, KeyPathCapability capability,
Type rootType, Type valueType) {
KeyPathMutability mutability;
Expand Down Expand Up @@ -664,51 +700,11 @@ static Type getKeyPathType(ASTContext &ctx, KeyPathCapability capability,
return keyPathTy;
}

bool BindingSet::finalize(bool transitive) {
if (transitive)
inferTransitiveBindings();

determineLiteralCoverage();

bool BindingSet::finalizeKeyPathBindings() {
if (auto *locator = TypeVar->getImpl().getLocator()) {
if (locator->isLastElement<LocatorPathElt::MemberRefBase>()) {
// If this is a base of an unresolved member chain, as a last
// resort effort let's infer base to be a protocol type based
// on contextual conformance requirements.
//
// This allows us to find solutions in cases like this:
//
// \code
// func foo<T: P>(_: T) {}
// foo(.bar) <- `.bar` should be a static member of `P`.
// \endcode
if (transitive && !hasViableBindings()) {
inferTransitiveProtocolRequirements();

if (TransitiveProtocols.has_value()) {
for (auto *constraint : *TransitiveProtocols) {
Type protocolTy = constraint->getSecondType();

// Compiler-known marker protocols cannot be extended with members,
// so do not consider them.
if (auto p = protocolTy->getAs<ProtocolType>()) {
if (ProtocolDecl *decl = p->getDecl())
if (decl->getKnownProtocolKind() && decl->isMarkerProtocol())
continue;
}

addBinding({protocolTy, AllowedBindingKind::Exact, constraint},
/*isTransitive=*/false);
}
}
}
}

if (TypeVar->getImpl().isKeyPathType()) {
auto &ctx = CS.getASTContext();

auto *keyPathLoc = TypeVar->getImpl().getLocator();
auto *keyPath = castToExpr<KeyPathExpr>(keyPathLoc->getAnchor());
auto *keyPath = castToExpr<KeyPathExpr>(locator->getAnchor());

bool isValid;
std::optional<KeyPathCapability> capability;
Expand Down Expand Up @@ -775,7 +771,7 @@ bool BindingSet::finalize(bool transitive) {
auto keyPathTy = getKeyPathType(ctx, *capability, rootTy,
CS.getKeyPathValueType(keyPath));
updatedBindings.insert(
{keyPathTy, AllowedBindingKind::Exact, keyPathLoc});
{keyPathTy, AllowedBindingKind::Exact, locator});
} else if (CS.shouldAttemptFixes()) {
auto fixedRootTy = CS.getFixedType(rootTy);
// If key path is structurally correct and has a resolved root
Expand All @@ -802,10 +798,14 @@ bool BindingSet::finalize(bool transitive) {

Bindings = std::move(updatedBindings);
Defaults.clear();

return true;
}
}

return true;
}

void BindingSet::finalizeUnresolvedMemberChainResult() {
if (auto *locator = TypeVar->getImpl().getLocator()) {
if (CS.shouldAttemptFixes() &&
locator->isLastElement<LocatorPathElt::UnresolvedMemberChainResult>()) {
// Let's see whether this chain is valid, if it isn't then to avoid
Expand All @@ -828,8 +828,6 @@ bool BindingSet::finalize(bool transitive) {
}
}
}

return true;
}

void BindingSet::addBinding(PotentialBinding binding, bool isTransitive) {
Expand Down Expand Up @@ -1143,37 +1141,6 @@ std::optional<BindingSet> ConstraintSystem::determineBestBindings(
node.initBindingSet();
}

// Determine whether given type variable with its set of bindings is
// viable to be attempted on the next step of the solver. If type variable
// has no "direct" bindings of any kind e.g. direct bindings to concrete
// types, default types from "defaultable" constraints or literal
// conformances, such type variable is not viable to be evaluated to be
// attempted next.
auto isViableForRanking = [this](const BindingSet &bindings) -> bool {
auto *typeVar = bindings.getTypeVariable();

// Key path root type variable is always viable because it can be
// transitively inferred from key path type during binding set
// finalization.
if (typeVar->getImpl().isKeyPathRoot())
return true;

// Type variable representing a base of unresolved member chain should
// always be considered viable for ranking since it's allow to infer
// types from transitive protocol requirements.
if (auto *locator = typeVar->getImpl().getLocator()) {
if (locator->isLastElement<LocatorPathElt::MemberRefBase>())
return true;
}

// If type variable is marked as a potential hole there is always going
// to be at least one binding available for it.
if (shouldAttemptFixes() && typeVar->getImpl().canBindToHole())
return true;

return bool(bindings);
};

// Now let's see if we could infer something for related type
// variables based on other bindings.
for (auto *typeVar : getTypeVariables()) {
Expand All @@ -1183,6 +1150,16 @@ std::optional<BindingSet> ConstraintSystem::determineBestBindings(

auto &bindings = node.getBindingSet();

// Special handling for key paths.
bindings.inferTransitiveKeyPathBindings();
if (!bindings.finalizeKeyPathBindings())
continue;

// Special handling for "leading-dot" unresolved member references,
// like .foo.
bindings.inferTransitiveUnresolvedMemberRefBindings();
bindings.finalizeUnresolvedMemberChainResult();

// Before attempting to infer transitive bindings let's check
// whether there are any viable "direct" bindings associated with
// current type variable, if there are none - it means that this type
Expand All @@ -1193,12 +1170,12 @@ std::optional<BindingSet> ConstraintSystem::determineBestBindings(
// associated with given type variable, any default constraints,
// or any conformance requirements to literal protocols with can
// produce a default type.
bool isViable = isViableForRanking(bindings);
bool isViable = bindings.isViable();

if (!bindings.finalize(true))
continue;
bindings.inferTransitiveSupertypeBindings();
bindings.determineLiteralCoverage();

if (!bindings || !isViable)
if (!isViable)
continue;

onCandidate(bindings);
Expand Down Expand Up @@ -1591,7 +1568,10 @@ BindingSet ConstraintSystem::getBindingsFor(TypeVariableType *typeVar) {
assert(!typeVar->getImpl().getFixedType(nullptr) && "has a fixed type");

BindingSet bindings(*this, typeVar, CG[typeVar].getPotentialBindings());
bindings.finalize(false);

(void) bindings.finalizeKeyPathBindings();
bindings.finalizeUnresolvedMemberChainResult();
bindings.determineLiteralCoverage();

return bindings;
}
Expand Down
4 changes: 3 additions & 1 deletion lib/Sema/CSOptimizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1105,7 +1105,9 @@ static void determineBestChoicesInContext(
// Simply adding it as a binding won't work because if the second argument
// is non-optional the overload that returns `T?` would still have a lower
// score.
if (!bindingSet && isNilCoalescingOperator(disjunction)) {
if (!bindingSet.hasViableBindings() &&
!bindingSet.isDirectHole() &&
isNilCoalescingOperator(disjunction)) {
auto &cg = cs.getConstraintGraph();
if (llvm::any_of(cg[typeVar].getConstraints(),
[&typeVar](Constraint *constraint) {
Expand Down
3 changes: 2 additions & 1 deletion lib/Sema/ConstraintGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -921,7 +921,8 @@ bool ConstraintGraph::contractEdges() {
// us enough information to decided on l-valueness.
if (tyvar1->getImpl().canBindToInOut()) {
bool isNotContractable = true;
if (auto bindings = CS.getBindingsFor(tyvar1)) {
auto bindings = CS.getBindingsFor(tyvar1);
if (bindings.isViable()) {
// Holes can't be contracted.
if (bindings.isHole())
continue;
Expand Down
10 changes: 9 additions & 1 deletion unittests/Sema/BindingInferenceTests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,15 @@ TEST_F(SemaTest, TestIntLiteralBindingInference) {

cs.getConstraintGraph()[floatLiteralTy].initBindingSet();

bindings.finalize(/*transitive=*/true);
bindings.inferTransitiveKeyPathBindings();
(void) bindings.finalizeKeyPathBindings();

bindings.inferTransitiveUnresolvedMemberRefBindings();
bindings.finalizeUnresolvedMemberChainResult();

bindings.inferTransitiveSupertypeBindings();

bindings.determineLiteralCoverage();

// Inferred a single transitive binding through `$T_float`.
ASSERT_EQ(bindings.Bindings.size(), (unsigned)1);
Expand Down
Loading