Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Treat Refinements more like AndTypes #12317

Merged
merged 4 commits into from
May 4, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 74 additions & 57 deletions compiler/src/dotty/tools/dotc/core/TypeComparer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,11 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
private inline def inFrozenGadtAndConstraint[T](inline op: T): T =
inFrozenGadtIf(true)(inFrozenConstraint(op))

extension (sym: Symbol)
private inline def onGadtBounds(inline op: TypeBounds => Boolean): Boolean =
val bounds = gadtBounds(sym)
bounds != null && op(bounds)

protected def isSubType(tp1: Type, tp2: Type, a: ApproxState): Boolean = {
val savedApprox = approx
val savedLeftRoot = leftRoot
Expand Down Expand Up @@ -465,19 +470,15 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
case AndType(tp21, tp22) => constrainRHSVars(tp21) && constrainRHSVars(tp22)
case _ => true

// An & on the left side loses information. We compensate by also trying the join.
// This is less ad-hoc than it looks since we produce joins in type inference,
// and then need to check that they are indeed supertypes of the original types
// under -Ycheck. Test case is i7965.scala.
def containsAnd(tp: Type): Boolean = tp.dealiasKeepRefiningAnnots match
case tp: AndType => true
case OrType(tp1, tp2) => containsAnd(tp1) || containsAnd(tp2)
case _ => false

widenOK
|| joinOK
|| (tp1.isSoft || constrainRHSVars(tp2)) && recur(tp11, tp2) && recur(tp12, tp2)
|| containsAnd(tp1) && inFrozenGadt(recur(tp1.join, tp2))
// An & on the left side loses information. We compensate by also trying the join.
// This is less ad-hoc than it looks since we produce joins in type inference,
// and then need to check that they are indeed supertypes of the original types
// under -Ycheck. Test case is i7965.scala.

case tp1: MatchType =>
val reduced = tp1.reduced
if (reduced.exists) recur(reduced, tp2) else thirdTry
Expand All @@ -489,11 +490,10 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling

def thirdTryNamed(tp2: NamedType): Boolean = tp2.info match {
case info2: TypeBounds =>
def compareGADT: Boolean = {
val gbounds2 = gadtBounds(tp2.symbol)
(gbounds2 != null) &&
(isSubTypeWhenFrozen(tp1, gbounds2.lo) ||
(tp1 match {
def compareGADT: Boolean =
tp2.symbol.onGadtBounds(gbounds2 =>
isSubTypeWhenFrozen(tp1, gbounds2.lo)
|| tp1.match
case tp1: NamedType if ctx.gadt.contains(tp1.symbol) =>
// Note: since we approximate constrained types only with their non-param bounds,
// we need to manually handle the case when we're comparing two constrained types,
Expand All @@ -502,10 +502,9 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
// comparing two constrained types, and that case will be handled here first.
ctx.gadt.isLess(tp1.symbol, tp2.symbol) && GADTusage(tp1.symbol) && GADTusage(tp2.symbol)
case _ => false
}) ||
narrowGADTBounds(tp2, tp1, approx, isUpper = false)) &&
{ isBottom(tp1) || GADTusage(tp2.symbol) }
}
|| narrowGADTBounds(tp2, tp1, approx, isUpper = false))
&& (isBottom(tp1) || GADTusage(tp2.symbol))

isSubApproxHi(tp1, info2.lo) || compareGADT || tryLiftedToThis2 || fourthTry

case _ =>
Expand Down Expand Up @@ -559,31 +558,35 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
case tp2: TypeParamRef =>
compareTypeParamRef(tp2)
case tp2: RefinedType =>
def compareRefinedSlow: Boolean = {
def compareRefinedSlow: Boolean =
val name2 = tp2.refinedName
recur(tp1, tp2.parent) &&
(name2 == nme.WILDCARD || hasMatchingMember(name2, tp1, tp2))
}
def compareRefined: Boolean = {
recur(tp1, tp2.parent)
&& (name2 == nme.WILDCARD || hasMatchingMember(name2, tp1, tp2))

def compareRefined: Boolean =
val tp1w = tp1.widen
val skipped2 = skipMatching(tp1w, tp2)
if ((skipped2 eq tp2) || !Config.fastPathForRefinedSubtype)
tp1 match {
case tp1: AndType =>
// Delay calling `compareRefinedSlow` because looking up a member
// of an `AndType` can lead to a cascade of subtyping checks
// This twist is needed to make collection/generic/ParFactory.scala compile
fourthTry || compareRefinedSlow
case tp1: HKTypeLambda =>
// HKTypeLambdas do not have members.
fourthTry
case _ =>
compareRefinedSlow || fourthTry
}
if (skipped2 eq tp2) || !Config.fastPathForRefinedSubtype then
if containsAnd(tp1) then
tp2.parent match
case _: RefinedType | _: AndType =>
// maximally decompose RHS to limit the bad effects of the `either` that is necessary
// since LHS contains an AndType
recur(tp1, decomposeRefinements(tp2, Nil))
case _ =>
// Delay calling `compareRefinedSlow` because looking up a member
// of an `AndType` can lead to a cascade of subtyping checks
// This twist is needed to make collection/generic/ParFactory.scala compile
fourthTry || compareRefinedSlow
else if tp1.isInstanceOf[HKTypeLambda] then
// HKTypeLambdas do not have members.
fourthTry
else
compareRefinedSlow || fourthTry
else // fast path, in particular for refinements resulting from parameterization.
isSubRefinements(tp1w.asInstanceOf[RefinedType], tp2, skipped2) &&
recur(tp1, skipped2)
}

compareRefined
case tp2: RecType =>
def compareRec = tp1.safeDealias match {
Expand Down Expand Up @@ -751,13 +754,12 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
case tp1: TypeRef =>
tp1.info match {
case TypeBounds(_, hi1) =>
def compareGADT = {
val gbounds1 = gadtBounds(tp1.symbol)
(gbounds1 != null) &&
(isSubTypeWhenFrozen(gbounds1.hi, tp2) ||
narrowGADTBounds(tp1, tp2, approx, isUpper = true)) &&
{ tp2.isAny || GADTusage(tp1.symbol) }
}
def compareGADT =
tp1.symbol.onGadtBounds(gbounds1 =>
isSubTypeWhenFrozen(gbounds1.hi, tp2)
|| narrowGADTBounds(tp1, tp2, approx, isUpper = true))
&& (tp2.isAny || GADTusage(tp1.symbol))

isSubType(hi1, tp2, approx.addLow) || compareGADT || tryLiftedToThis1
case _ =>
def isNullable(tp: Type): Boolean = tp.widenDealias match {
Expand Down Expand Up @@ -1033,17 +1035,12 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling

var touchedGADTs = false
var gadtIsInstantiated = false
def byGadtBounds(sym: Symbol, tp: Type, fromAbove: Boolean): Boolean = {
touchedGADTs = true
val b = gadtBounds(sym)
def boundsDescr = if b == null then "null" else b.show
b != null && inFrozenGadt {
if fromAbove then isSubType(b.hi, tp) else isSubType(tp, b.lo)
} && {
gadtIsInstantiated = b.isInstanceOf[TypeAlias]
true
}
}

extension (sym: Symbol)
inline def byGadtBounds(inline op: TypeBounds => Boolean): Boolean =
touchedGADTs = true
sym.onGadtBounds(
b => op(b) && { gadtIsInstantiated = b.isInstanceOf[TypeAlias]; true })

def byGadtOrdering: Boolean =
ctx.gadt.contains(tycon1sym)
Expand All @@ -1052,8 +1049,8 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling

val res = (
tycon1sym == tycon2sym && isSubPrefix(tycon1.prefix, tycon2.prefix)
|| byGadtBounds(tycon1sym, tycon2, fromAbove = true)
|| byGadtBounds(tycon2sym, tycon1, fromAbove = false)
|| tycon1sym.byGadtBounds(b => isSubTypeWhenFrozen(b.hi, tycon2))
|| tycon2sym.byGadtBounds(b => isSubTypeWhenFrozen(tycon1, b.lo))
|| byGadtOrdering
) && {
// There are two cases in which we can assume injectivity.
Expand Down Expand Up @@ -1691,6 +1688,26 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
else op2
end necessaryEither

/** Decompose into conjunction of types each of which has only a single refinement */
def decomposeRefinements(tp: Type, refines: List[(Name, Type)]): Type = tp match
case RefinedType(parent, rname, rinfo) =>
decomposeRefinements(parent, (rname, rinfo) :: refines)
case AndType(tp1, tp2) =>
AndType(decomposeRefinements(tp1, refines), decomposeRefinements(tp2, refines))
case _ =>
refines.map(RefinedType(tp, _, _): Type).reduce(AndType(_, _))

/** Can comparing this type on the left lead to an either? This is the case if
* the type is and AndType or contains embedded occurrences of AndTypes
*/
def containsAnd(tp: Type): Boolean = tp match
case tp: AndType => true
case OrType(tp1, tp2) => containsAnd(tp1) || containsAnd(tp2)
case tp: TypeParamRef => containsAnd(bounds(tp).hi)
case tp: TypeRef => containsAnd(tp.info.hiBound) || tp.symbol.onGadtBounds(gbounds => containsAnd(gbounds.hi))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, interesting. We always compare with GADT bounds in frozen constraint, but that doesn't end up mattering, right? That is: do I understand correctly that even if we're not inferring constraints, we need to break down structural types on the RHS to properly check if there's subtyping between the left- and right-hand sides?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the same reasoning applies: If we end up comparing GADT constraints we need necessaryEither, and that means we need to decompose the RHS as much as possible so that necesssaryEither is sound.

case tp: TypeProxy => containsAnd(tp.superType)
case _ => false

/** Does type `tp1` have a member with name `name` whose normalized type is a subtype of
* the normalized type of the refinement `tp2`?
* Normalization is as follows: If `tp2` contains a skolem to its refinement type,
Expand Down
23 changes: 23 additions & 0 deletions tests/pos/i12306.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
class Record(elems: Map[String, Any]) extends Selectable:
val fields = elems.toMap
def selectDynamic(name: String): Any = fields(name)
object Record:
def apply(elems: Map[String, Any]): Record = new Record(elems)
extension [A <: Record] (a:A) {
def join[B <: Record] (b:B): A & B = {
Record(a.fields ++ b.fields).asInstanceOf[A & B]
}
}

type Person = Record { val name: String; val age: Int }
type Child = Record { val parent: String }
type PersonAndChild = Record { val name: String; val age: Int; val parent: String }

@main def hello = {
val person = Record(Map("name" -> "Emma", "age" -> 42)).asInstanceOf[Person]
val child = Record(Map("parent" -> "Alice")).asInstanceOf[Child]
val personAndChild = person.join(child)

val v1: PersonAndChild = personAndChild
val v2: PersonAndChild = person.join(child)
}