diff --git a/include/swift/Sema/ConstraintSystem.h b/include/swift/Sema/ConstraintSystem.h index 146126fd8a4f3..51fd3a39c0057 100644 --- a/include/swift/Sema/ConstraintSystem.h +++ b/include/swift/Sema/ConstraintSystem.h @@ -1014,6 +1014,11 @@ enum ScoreKind: unsigned int { SK_EmptyExistentialConversion, /// A key path application subscript. SK_KeyPathSubscript, + /// A pointer conversion where the destination type is a generic parameter. + /// This should eventually be removed in favor of outright banning pointer + /// conversions for generic parameters. As such we consider it more impactful + /// than \c SK_ValueToPointerConversion. + SK_GenericParamPointerConversion, /// A conversion from a string, array, or inout to a pointer. SK_ValueToPointerConversion, /// A closure/function conversion to an autoclosure parameter. @@ -1191,6 +1196,9 @@ struct Score { case SK_KeyPathSubscript: return "key path subscript"; + case SK_GenericParamPointerConversion: + return "pointer conversion to generic parameter"; + case SK_ValueToPointerConversion: return "value-to-pointer conversion"; diff --git a/lib/Sema/CSSimplify.cpp b/lib/Sema/CSSimplify.cpp index 4cf84fecad1e8..b093b4f2ce385 100644 --- a/lib/Sema/CSSimplify.cpp +++ b/lib/Sema/CSSimplify.cpp @@ -15411,6 +15411,50 @@ ConstraintSystem::simplifyRestrictedConstraintImpl( llvm_unreachable("bad conversion restriction"); } +static void increaseScoreForGenericParamPointerConversion( + ConstraintSystem &cs, ConversionRestrictionKind kind, + ConstraintLocatorBuilder locator) { + switch (kind) { + case ConversionRestrictionKind::InoutToPointer: + case ConversionRestrictionKind::ArrayToPointer: + case ConversionRestrictionKind::StringToPointer: + case ConversionRestrictionKind::PointerToPointer: + break; + default: + return; + } + + auto *loc = cs.getConstraintLocator(locator); + auto argInfo = loc->findLast(); + if (!argInfo) + return; + + auto overload = cs.findSelectedOverloadFor(cs.getCalleeLocator(loc)); + if (!overload) + return; + + auto *D = overload->choice.getDeclOrNull(); + if (!D) + return; + + auto *param = getParameterAt(D, argInfo->getParamIdx()); + if (!param) + return; + + // Check to see if the parameter is a generic parameter, or dependent member. + auto paramTy = param->getInterfaceType()->lookThroughAllOptionalTypes(); + if (!paramTy->isTypeParameter()) + return; + + // Don't increase the score if the type parameter is rooted on the protocol + // 'Self' type, e.g extensions on `_Pointer` shouldn't be penalized. + if (auto *PD = D->getDeclContext()->getSelfProtocolDecl()) { + if (PD->getSelfInterfaceType()->isEqual(paramTy->getRootGenericParam())) + return; + } + cs.increaseScore(SK_GenericParamPointerConversion, locator); +} + ConstraintSystem::SolutionKind ConstraintSystem::simplifyRestrictedConstraint( ConversionRestrictionKind restriction, @@ -15438,6 +15482,13 @@ ConstraintSystem::simplifyRestrictedConstraint( addFixConstraint(fix, matchKind, type1, type2, locator); } + // Increase the score if needed for a pointer conversion to a generic + // parameter type. + // FIXME: We ought to consider outright banning pointer conversions to + // generic parameter types, in which case this logic could be adjusted to + // record a fix instead. + increaseScoreForGenericParamPointerConversion(*this, restriction, locator); + addConversionRestriction(type1, type2, restriction); return SolutionKind::Solved; } diff --git a/test/Constraints/valid_pointer_conversions.swift b/test/Constraints/valid_pointer_conversions.swift index f759c41f326e0..8e5ed0b07b067 100644 --- a/test/Constraints/valid_pointer_conversions.swift +++ b/test/Constraints/valid_pointer_conversions.swift @@ -88,3 +88,31 @@ do { let _: UInt8 = result1 // Ok let _: [UInt8] = result2 // Ok } + +protocol PointerProtocol {} +extension UnsafePointer: PointerProtocol {} + +extension PointerProtocol { + func foo(_ x: Self) {} // expected-note {{found this candidate}} + func foo(_ x: UnsafePointer) {} // expected-note {{found this candidate}} +} + +func testGenericPointerConversions( + chars: [CChar], mutablePtr: UnsafeMutablePointer, ptr: UnsafePointer +) { + func id(_ x: T) -> T { x } + func optID(_ x: T?) -> T { x! } + func takesCharPtrs(_: UnsafePointer, _: UnsafePointer?) {} + + // Make sure we don't end up with an ambiguity here, we should prefer to + // do the pointer conversion for `takesPtrs` not `id`. + takesCharPtrs(chars, "a") + takesCharPtrs(id(chars), id("a")) + takesCharPtrs(id("a"), optID(chars)) + takesCharPtrs(mutablePtr, mutablePtr) + takesCharPtrs(id(mutablePtr), id(mutablePtr)) + takesCharPtrs(id(mutablePtr), optID(mutablePtr)) + + // Make sure this is ambiguous. + ptr.foo(chars) // expected-error {{ambiguous use of 'foo'}} +} diff --git a/unittests/Sema/ConstraintSystemDumpTests.cpp b/unittests/Sema/ConstraintSystemDumpTests.cpp index 30b137a433432..b650b46aa3452 100644 --- a/unittests/Sema/ConstraintSystemDumpTests.cpp +++ b/unittests/Sema/ConstraintSystemDumpTests.cpp @@ -33,7 +33,7 @@ TEST_F(SemaTest, DumpConstraintSystemBasic) { TupleType::get({Type(t0), Type(t1)}, Context), emptyLoc)); std::string expectedOutput = - R"(Score: + R"(Score: Type Variables: $T0 [can bind to: lvalue] [adjacent to: $T1, $T2] [potential bindings: ] @ locator@ [] $T1 [adjacent to: $T0, $T2] [potential bindings: ] @ locator@ []