Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Eliminate class hierarchy in GadtConstraint #16194

Merged
merged 3 commits into from Oct 25, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/core/Contexts.scala
Expand Up @@ -814,7 +814,7 @@ object Contexts {
.updated(notNullInfosLoc, Nil)
.updated(compilationUnitLoc, NoCompilationUnit)
searchHistory = new SearchRoot
gadt = EmptyGadtConstraint
gadt = GadtConstraint.empty
}

@sharable object NoContext extends Context((null: ContextBase | Null).uncheckedNN) {
Expand Down
163 changes: 53 additions & 110 deletions compiler/src/dotty/tools/dotc/core/GadtConstraint.scala
Expand Up @@ -10,77 +10,25 @@ import util.{SimpleIdentitySet, SimpleIdentityMap}
import collection.mutable
import printing._

import scala.annotation.internal.sharable
object GadtConstraint:
def apply(): GadtConstraint = empty
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we have a "canonical" empty GadtConstraint, to avoid allocating an extra object every time we create a new constraint?

def empty: GadtConstraint =
new ProperGadtConstraint(OrderingConstraint.empty, SimpleIdentityMap.empty, SimpleIdentityMap.empty, false)

/** Represents GADT constraints currently in scope */
sealed abstract class GadtConstraint extends Showable {
/** Immediate bounds of `sym`. Does not contain lower/upper symbols (see [[fullBounds]]). */
def bounds(sym: Symbol)(using Context): TypeBounds | Null

/** Full bounds of `sym`, including TypeRefs to other lower/upper symbols.
*
* @note this performs subtype checks between ordered symbols.
* Using this in isSubType can lead to infinite recursion. Consider `bounds` instead.
*/
def fullBounds(sym: Symbol)(using Context): TypeBounds | Null

/** Is `sym1` ordered to be less than `sym2`? */
def isLess(sym1: Symbol, sym2: Symbol)(using Context): Boolean

/** Add symbols to constraint, correctly handling inter-dependencies.
*
* @see [[ConstraintHandling.addToConstraint]]
*/
def addToConstraint(syms: List[Symbol])(using Context): Boolean
def addToConstraint(sym: Symbol)(using Context): Boolean = addToConstraint(sym :: Nil)

/** Further constrain a symbol already present in the constraint. */
def addBound(sym: Symbol, bound: Type, isUpper: Boolean)(using Context): Boolean

/** Is the symbol registered in the constraint?
*
* @note this is true even if the symbol is constrained to be equal to another type, unlike [[Constraint.contains]].
*/
def contains(sym: Symbol)(using Context): Boolean

/** GADT constraint narrows bounds of at least one variable */
def isNarrowing: Boolean

/** See [[ConstraintHandling.approximation]] */
def approximation(sym: Symbol, fromBelow: Boolean, maxLevel: Int = Int.MaxValue)(using Context): Type

def symbols: List[Symbol]

def fresh: GadtConstraint

/** Restore the state from other [[GadtConstraint]], probably copied using [[fresh]] */
def restore(other: GadtConstraint): Unit

/** Provides more information than toText, by showing the underlying Constraint details. */
def debugBoundsDescription(using Context): String
}

final class ProperGadtConstraint private(
sealed trait GadtConstraint (
private var myConstraint: Constraint,
private var mapping: SimpleIdentityMap[Symbol, TypeVar],
private var reverseMapping: SimpleIdentityMap[TypeParamRef, Symbol],
private var wasConstrained: Boolean
) extends GadtConstraint with ConstraintHandling {
import dotty.tools.dotc.config.Printers.{gadts, gadtsConstr}
) extends Showable {
this: ConstraintHandling =>

def this() = this(
myConstraint = new OrderingConstraint(SimpleIdentityMap.empty, SimpleIdentityMap.empty, SimpleIdentityMap.empty, SimpleIdentitySet.empty),
mapping = SimpleIdentityMap.empty,
reverseMapping = SimpleIdentityMap.empty,
wasConstrained = false
)
import dotty.tools.dotc.config.Printers.{gadts, gadtsConstr}

/** Exposes ConstraintHandling.subsumes */
def subsumes(left: GadtConstraint, right: GadtConstraint, pre: GadtConstraint)(using Context): Boolean = {
def extractConstraint(g: GadtConstraint) = g match {
case s: ProperGadtConstraint => s.constraint
case EmptyGadtConstraint => OrderingConstraint.empty
}
def extractConstraint(g: GadtConstraint) = g.constraint
subsumes(extractConstraint(left), extractConstraint(right), extractConstraint(pre))
}

Expand All @@ -89,7 +37,12 @@ final class ProperGadtConstraint private(
// the case where they're valid, so no approximating is needed.
rawBound

override def addToConstraint(params: List[Symbol])(using Context): Boolean = {
/** Add symbols to constraint, correctly handling inter-dependencies.
*
* @see [[ConstraintHandling.addToConstraint]]
*/
def addToConstraint(sym: Symbol)(using Context): Boolean = addToConstraint(sym :: Nil)
def addToConstraint(params: List[Symbol])(using Context): Boolean = {
import NameKinds.DepParamName

val poly1 = PolyType(params.map { sym => DepParamName.fresh(sym.name.toTypeName) })(
Expand Down Expand Up @@ -138,7 +91,8 @@ final class ProperGadtConstraint private(
.showing(i"added to constraint: [$poly1] $params%, % gadt = $this", gadts)
}

override def addBound(sym: Symbol, bound: Type, isUpper: Boolean)(using Context): Boolean = {
/** Further constrain a symbol already present in the constraint. */
def addBound(sym: Symbol, bound: Type, isUpper: Boolean)(using Context): Boolean = {
@annotation.tailrec def stripInternalTypeVar(tp: Type): Type = tp match {
case tv: TypeVar =>
val inst = constraint.instType(tv)
Expand Down Expand Up @@ -179,10 +133,16 @@ final class ProperGadtConstraint private(
result
}

override def isLess(sym1: Symbol, sym2: Symbol)(using Context): Boolean =
/** Is `sym1` ordered to be less than `sym2`? */
def isLess(sym1: Symbol, sym2: Symbol)(using Context): Boolean =
constraint.isLess(tvarOrError(sym1).origin, tvarOrError(sym2).origin)

override def fullBounds(sym: Symbol)(using Context): TypeBounds | Null =
/** Full bounds of `sym`, including TypeRefs to other lower/upper symbols.
*
* @note this performs subtype checks between ordered symbols.
* Using this in isSubType can lead to infinite recursion. Consider `bounds` instead.
*/
def fullBounds(sym: Symbol)(using Context): TypeBounds | Null =
mapping(sym) match {
case null => null
// TODO: Improve flow typing so that ascription becomes redundant
Expand All @@ -191,7 +151,8 @@ final class ProperGadtConstraint private(
// .ensuring(containsNoInternalTypes(_))
}

override def bounds(sym: Symbol)(using Context): TypeBounds | Null =
/** Immediate bounds of `sym`. Does not contain lower/upper symbols (see [[fullBounds]]). */
def bounds(sym: Symbol)(using Context): TypeBounds | Null =
mapping(sym) match {
case null => null
// TODO: Improve flow typing so that ascription becomes redundant
Expand All @@ -202,11 +163,17 @@ final class ProperGadtConstraint private(
//.ensuring(containsNoInternalTypes(_))
}

override def contains(sym: Symbol)(using Context): Boolean = mapping(sym) != null
/** Is the symbol registered in the constraint?
*
* @note this is true even if the symbol is constrained to be equal to another type, unlike [[Constraint.contains]].
*/
def contains(sym: Symbol)(using Context): Boolean = mapping(sym) != null

/** GADT constraint narrows bounds of at least one variable */
def isNarrowing: Boolean = wasConstrained

override def approximation(sym: Symbol, fromBelow: Boolean, maxLevel: Int)(using Context): Type = {
/** See [[ConstraintHandling.approximation]] */
def approximation(sym: Symbol, fromBelow: Boolean, maxLevel: Int = Int.MaxValue)(using Context): Type = {
val res =
approximation(tvarOrError(sym).origin, fromBelow, maxLevel) match
case tpr: TypeParamRef =>
Expand All @@ -220,23 +187,16 @@ final class ProperGadtConstraint private(
res
}

override def symbols: List[Symbol] = mapping.keys

override def fresh: GadtConstraint = new ProperGadtConstraint(
myConstraint,
mapping,
reverseMapping,
wasConstrained
)

def restore(other: GadtConstraint): Unit = other match {
case other: ProperGadtConstraint =>
this.myConstraint = other.myConstraint
this.mapping = other.mapping
this.reverseMapping = other.reverseMapping
this.wasConstrained = other.wasConstrained
case _ => ;
}
def symbols: List[Symbol] = mapping.keys

def fresh: GadtConstraint = new ProperGadtConstraint(myConstraint, mapping, reverseMapping, wasConstrained)

/** Restore the state from other [[GadtConstraint]], probably copied using [[fresh]] */
def restore(other: GadtConstraint): Unit =
this.myConstraint = other.myConstraint
this.mapping = other.mapping
this.reverseMapping = other.reverseMapping
this.wasConstrained = other.wasConstrained

// ---- Protected/internal -----------------------------------------------

Expand Down Expand Up @@ -294,30 +254,13 @@ final class ProperGadtConstraint private(

override def toText(printer: Printer): Texts.Text = printer.toText(this)

override def debugBoundsDescription(using Context): String = i"$this\n$constraint"
/** Provides more information than toText, by showing the underlying Constraint details. */
def debugBoundsDescription(using Context): String = i"$this\n$constraint"
}

@sharable object EmptyGadtConstraint extends GadtConstraint {
override def bounds(sym: Symbol)(using Context): TypeBounds | Null = null
override def fullBounds(sym: Symbol)(using Context): TypeBounds | Null = null

override def isLess(sym1: Symbol, sym2: Symbol)(using Context): Boolean = unsupported("EmptyGadtConstraint.isLess")

override def isNarrowing: Boolean = false

override def contains(sym: Symbol)(using Context) = false

override def addToConstraint(params: List[Symbol])(using Context): Boolean = unsupported("EmptyGadtConstraint.addToConstraint")
override def addBound(sym: Symbol, bound: Type, isUpper: Boolean)(using Context): Boolean = unsupported("EmptyGadtConstraint.addBound")

override def approximation(sym: Symbol, fromBelow: Boolean, maxLevel: Int)(using Context): Type = unsupported("EmptyGadtConstraint.approximation")

override def symbols: List[Symbol] = Nil

override def fresh = new ProperGadtConstraint
override def restore(other: GadtConstraint): Unit =
assert(!other.isNarrowing, "cannot restore a non-empty GADTMap")

override def toText(printer: Printer): Texts.Text = printer.toText(this)
override def debugBoundsDescription(using Context): String = i"$this"
}
private class ProperGadtConstraint (
myConstraint: Constraint,
mapping: SimpleIdentityMap[Symbol, TypeVar],
reverseMapping: SimpleIdentityMap[TypeParamRef, Symbol],
wasConstrained: Boolean,
) extends ConstraintHandling with GadtConstraint(myConstraint, mapping, reverseMapping, wasConstrained)
6 changes: 1 addition & 5 deletions compiler/src/dotty/tools/dotc/core/TypeComparer.scala
Expand Up @@ -1830,11 +1830,7 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
val preGadt = ctx.gadt.fresh

def allSubsumes(leftGadt: GadtConstraint, rightGadt: GadtConstraint, left: Constraint, right: Constraint): Boolean =
subsumes(left, right, preConstraint) && preGadt.match
case preGadt: ProperGadtConstraint =>
preGadt.subsumes(leftGadt, rightGadt, preGadt)
case _ =>
true
subsumes(left, right, preConstraint) && preGadt.subsumes(leftGadt, rightGadt, preGadt)

if op1 then
val op1Constraint = constraint
Expand Down
12 changes: 5 additions & 7 deletions compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala
Expand Up @@ -693,13 +693,11 @@ class PlainPrinter(_ctx: Context) extends Printer {
finally
ctx.typerState.constraint = savedConstraint

def toText(g: GadtConstraint): Text = g match
case EmptyGadtConstraint => "EmptyGadtConstraint"
case g: ProperGadtConstraint =>
val deps = for sym <- g.symbols yield
val bound = g.fullBounds(sym).nn
(typeText(toText(sym.typeRef)) ~ toText(bound)).close
("GadtConstraint(" ~ Text(deps, ", ") ~ ")").close
def toText(g: GadtConstraint): Text =
val deps = for sym <- g.symbols yield
val bound = g.fullBounds(sym).nn
(typeText(toText(sym.typeRef)) ~ toText(bound)).close
("GadtConstraint(" ~ Text(deps, ", ") ~ ")").close

def plain: PlainPrinter = this

Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/typer/Typer.scala
Expand Up @@ -3774,7 +3774,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
adaptToSubType(wtp)
case CompareResult.OKwithGADTUsed
if pt.isValueType
&& !inContext(ctx.fresh.setGadt(EmptyGadtConstraint)) {
&& !inContext(ctx.fresh.setGadt(GadtConstraint.empty)) {
val res = (tree.tpe.widenExpr frozen_<:< pt)
if res then
// we overshot; a cast is not needed, after all.
Expand Down