Skip to content

Commit

Permalink
Pattern match support in checking global objects (#18127)
Browse files Browse the repository at this point in the history
Pattern match in checking global objects
  • Loading branch information
liufengyun committed Jul 14, 2023
2 parents ca29cdc + 4cfcacf commit 04eae14
Show file tree
Hide file tree
Showing 14 changed files with 322 additions and 31 deletions.
215 changes: 201 additions & 14 deletions compiler/src/dotty/tools/dotc/transform/init/Objects.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,18 @@ import core.*
import Contexts.*
import Symbols.*
import Types.*
import Denotations.Denotation
import StdNames.*
import Names.TermName
import NameKinds.OuterSelectName
import NameKinds.SuperAccessorName

import ast.tpd.*
import util.SourcePosition
import util.{ SourcePosition, NoSourcePosition }
import config.Printers.init as printer
import reporting.StoreReporter
import reporting.trace as log
import typer.Applications.*

import Errors.*
import Trace.*
Expand Down Expand Up @@ -249,7 +252,7 @@ object Objects:
val joinedTrace = data.pendingTraces.slice(index + 1, data.checkingObjects.size).foldLeft(pendingTrace) { (a, acc) => acc ++ a }
val callTrace = Trace.buildStacktrace(joinedTrace, "Calling trace:\n")
val cycle = data.checkingObjects.slice(index, data.checkingObjects.size)
val pos = clazz.defTree
val pos = clazz.defTree.sourcePos.focus
report.warning("Cyclic initialization: " + cycle.map(_.klass.show).mkString(" -> ") + " -> " + clazz.show + ". " + callTrace, pos)
end if
data.checkingObjects(index)
Expand Down Expand Up @@ -834,11 +837,10 @@ object Objects:

/** Handle local variable definition, `val x = e` or `var x = e`.
*
* @param ref The value for `this` where the variable is defined.
* @param sym The symbol of the variable.
* @param value The value of the initializer.
*/
def initLocal(ref: Ref, sym: Symbol, value: Value): Contextual[Unit] = log("initialize local " + sym.show + " with " + value.show, printer) {
def initLocal(sym: Symbol, value: Value): Contextual[Unit] = log("initialize local " + sym.show + " with " + value.show, printer) {
if sym.is(Flags.Mutable) then
val addr = Heap.localVarAddr(summon[Regions.Data], sym, State.currentObject)
Env.setLocalVar(sym, addr)
Expand Down Expand Up @@ -870,9 +872,6 @@ object Objects:
case _ =>
report.warning("[Internal error] Variable not found " + sym.show + "\nenv = " + env.show + ". Calling trace:\n" + Trace.show, Trace.position)
Bottom
else if sym.isPatternBound then
// TODO: handle patterns
Cold
else
given Env.Data = env
// Assume forward reference check is doing a good job
Expand Down Expand Up @@ -1113,11 +1112,9 @@ object Objects:
else
eval(arg, thisV, klass)

case Match(selector, cases) =>
eval(selector, thisV, klass)
// TODO: handle pattern match properly
report.warning("[initChecker] Pattern match is skipped. Trace:\n" + Trace.show, expr)
Bottom
case Match(scrutinee, cases) =>
val scrutineeValue = eval(scrutinee, thisV, klass)
patternMatch(scrutineeValue, cases, thisV, klass)

case Return(expr, from) =>
Returns.handle(from.symbol, eval(expr, thisV, klass))
Expand Down Expand Up @@ -1151,7 +1148,7 @@ object Objects:
// local val definition
val rhs = eval(vdef.rhs, thisV, klass)
val sym = vdef.symbol
initLocal(thisV.asInstanceOf[Ref], vdef.symbol, rhs)
initLocal(vdef.symbol, rhs)
Bottom

case ddef : DefDef =>
Expand All @@ -1173,6 +1170,196 @@ object Objects:
Bottom
}

/** Evaluate the cases against the scrutinee value.
*
* It returns the scrutinee in most cases. The main effect of the function is for its side effects of adding bindings
* to the environment.
*
* See https://docs.scala-lang.org/scala3/reference/changed-features/pattern-matching.html
*
* @param scrutinee The abstract value of the scrutinee.
* @param cases The cases to match.
* @param thisV The value for `C.this` where `C` is represented by `klass`.
* @param klass The enclosing class where the type `tp` is located.
*/
def patternMatch(scrutinee: Value, cases: List[CaseDef], thisV: Value, klass: ClassSymbol): Contextual[Value] =
// expected member types for `unapplySeq`
def lengthType = ExprType(defn.IntType)
def lengthCompareType = MethodType(List(defn.IntType), defn.IntType)
def applyType(elemTp: Type) = MethodType(List(defn.IntType), elemTp)
def dropType(elemTp: Type) = MethodType(List(defn.IntType), defn.CollectionSeqType.appliedTo(elemTp))
def toSeqType(elemTp: Type) = ExprType(defn.CollectionSeqType.appliedTo(elemTp))

def getMemberMethod(receiver: Type, name: TermName, tp: Type): Denotation =
receiver.member(name).suchThat(receiver.memberInfo(_) <:< tp)

def evalCase(caseDef: CaseDef): Value =
evalPattern(scrutinee, caseDef.pat)
eval(caseDef.guard, thisV, klass)
eval(caseDef.body, thisV, klass)

/** Abstract evaluation of patterns.
*
* It augments the local environment for bound pattern variables. As symbols are globally
* unique, we can put them in a single environment.
*
* Currently, we assume all cases are reachable, thus all patterns are assumed to match.
*/
def evalPattern(scrutinee: Value, pat: Tree): Value = log("match " + scrutinee.show + " against " + pat.show, printer, (_: Value).show):
val trace2 = Trace.trace.add(pat)
pat match
case Alternative(pats) =>
for pat <- pats do evalPattern(scrutinee, pat)
scrutinee

case bind @ Bind(_, pat) =>
val value = evalPattern(scrutinee, pat)
initLocal(bind.symbol, value)
scrutinee

case UnApply(fun, implicits, pats) =>
given Trace = trace2

val fun1 = funPart(fun)
val funRef = fun1.tpe.asInstanceOf[TermRef]
val unapplyResTp = funRef.widen.finalResultType

val receiver = fun1 match
case ident: Ident =>
evalType(funRef.prefix, thisV, klass)
case select: Select =>
eval(select.qualifier, thisV, klass)

val implicitValues = evalArgs(implicits.map(Arg.apply), thisV, klass)
// TODO: implicit values may appear before and/or after the scrutinee parameter.
val unapplyRes = call(receiver, funRef.symbol, TraceValue(scrutinee, summon[Trace]) :: implicitValues, funRef.prefix, superType = NoType, needResolve = true)

if fun.symbol.name == nme.unapplySeq then
var resultTp = unapplyResTp
var elemTp = unapplySeqTypeElemTp(resultTp)
var arity = productArity(resultTp, NoSourcePosition)
var needsGet = false
if (!elemTp.exists && arity <= 0) {
needsGet = true
resultTp = resultTp.select(nme.get).finalResultType
elemTp = unapplySeqTypeElemTp(resultTp.widen)
arity = productSelectorTypes(resultTp, NoSourcePosition).size
}

var resToMatch = unapplyRes

if needsGet then
// Get match
val isEmptyDenot = unapplyResTp.member(nme.isEmpty).suchThat(_.info.isParameterless)
call(unapplyRes, isEmptyDenot.symbol, Nil, unapplyResTp, superType = NoType, needResolve = true)

val getDenot = unapplyResTp.member(nme.get).suchThat(_.info.isParameterless)
resToMatch = call(unapplyRes, getDenot.symbol, Nil, unapplyResTp, superType = NoType, needResolve = true)
end if

if elemTp.exists then
// sequence match
evalSeqPatterns(resToMatch, resultTp, elemTp, pats)
else
// product sequence match
val selectors = productSelectors(resultTp)
assert(selectors.length <= pats.length)
selectors.init.zip(pats).map { (sel, pat) =>
val selectRes = call(resToMatch, sel, Nil, resultTp, superType = NoType, needResolve = true)
evalPattern(selectRes, pat)
}
val seqPats = pats.drop(selectors.length - 1)
val toSeqRes = call(resToMatch, selectors.last, Nil, resultTp, superType = NoType, needResolve = true)
val toSeqResTp = resultTp.memberInfo(selectors.last).finalResultType
evalSeqPatterns(toSeqRes, toSeqResTp, elemTp, seqPats)
end if

else
// distribute unapply to patterns
if isProductMatch(unapplyResTp, pats.length) then
// product match
val selectors = productSelectors(unapplyResTp)
assert(selectors.length == pats.length)
selectors.zip(pats).map { (sel, pat) =>
val selectRes = call(unapplyRes, sel, Nil, unapplyResTp, superType = NoType, needResolve = true)
evalPattern(selectRes, pat)
}
else if unapplyResTp <:< defn.BooleanType then
// Boolean extractor, do nothing
()
else
// Get match
val isEmptyDenot = unapplyResTp.member(nme.isEmpty).suchThat(_.info.isParameterless)
call(unapplyRes, isEmptyDenot.symbol, Nil, unapplyResTp, superType = NoType, needResolve = true)

val getDenot = unapplyResTp.member(nme.get).suchThat(_.info.isParameterless)
val getRes = call(unapplyRes, getDenot.symbol, Nil, unapplyResTp, superType = NoType, needResolve = true)
if pats.length == 1 then
// single match
evalPattern(getRes, pats.head)
else
val getResTp = getDenot.info.finalResultType
val selectors = productSelectors(getResTp).take(pats.length)
selectors.zip(pats).map { (sel, pat) =>
val selectRes = call(unapplyRes, sel, Nil, getResTp, superType = NoType, needResolve = true)
evalPattern(selectRes, pat)
}
end if
end if
end if
scrutinee

case Ident(nme.WILDCARD) | Ident(nme.WILDCARD_STAR) =>
scrutinee

case Typed(pat, _) =>
evalPattern(scrutinee, pat)

case tree =>
// For all other trees, the semantics is normal.
eval(tree, thisV, klass)

end evalPattern

/**
* Evaluate a sequence value against sequence patterns.
*/
def evalSeqPatterns(scrutinee: Value, scrutineeType: Type, elemType: Type, pats: List[Tree])(using Trace): Unit =
// call .lengthCompare or .length
val lengthCompareDenot = getMemberMethod(scrutineeType, nme.lengthCompare, lengthCompareType)
if lengthCompareDenot.exists then
call(scrutinee, lengthCompareDenot.symbol, TraceValue(Bottom, summon[Trace]) :: Nil, scrutineeType, superType = NoType, needResolve = true)
else
val lengthDenot = getMemberMethod(scrutineeType, nme.length, lengthType)
call(scrutinee, lengthDenot.symbol, Nil, scrutineeType, superType = NoType, needResolve = true)
end if

// call .apply
val applyDenot = getMemberMethod(scrutineeType, nme.apply, applyType(elemType))
val applyRes = call(scrutinee, applyDenot.symbol, TraceValue(Bottom, summon[Trace]) :: Nil, scrutineeType, superType = NoType, needResolve = true)

if isWildcardStarArg(pats.last) then
if pats.size == 1 then
// call .toSeq
val toSeqDenot = scrutineeType.member(nme.toSeq).suchThat(_.info.isParameterless)
val toSeqRes = call(scrutinee, toSeqDenot.symbol, Nil, scrutineeType, superType = NoType, needResolve = true)
evalPattern(toSeqRes, pats.head)
else
// call .drop
val dropDenot = getMemberMethod(scrutineeType, nme.drop, applyType(elemType))
val dropRes = call(scrutinee, dropDenot.symbol, TraceValue(Bottom, summon[Trace]) :: Nil, scrutineeType, superType = NoType, needResolve = true)
for pat <- pats.init do evalPattern(applyRes, pat)
evalPattern(dropRes, pats.last)
end if
else
// no patterns like `xs*`
for pat <- pats do evalPattern(applyRes, pat)
end evalSeqPatterns


cases.map(evalCase).join
end patternMatch

/** Handle semantics of leaf nodes
*
* For leaf nodes, their semantics is determined by their types.
Expand Down Expand Up @@ -1231,7 +1418,7 @@ object Objects:
resolveThis(tref.classSymbol.asClass, thisV, klass)

case _ =>
throw new Exception("unexpected type: " + tp)
throw new Exception("unexpected type: " + tp + ", Trace:\n" + Trace.show)
}

/** Evaluate arguments of methods and constructors */
Expand Down
7 changes: 6 additions & 1 deletion compiler/src/dotty/tools/dotc/transform/init/Trace.scala
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,12 @@ object Trace:
val code = SyntaxHighlighting.highlight(pos.lineContent.trim.nn)
i"$code\t$loc"
else
tree.show
tree match
case defDef: DefTree =>
// The definition can be huge, avoid printing the whole definition.
defDef.symbol.show
case _ =>
tree.show
val positionMarkerLine =
if pos.exists && pos.source.exists then
positionMarker(pos)
Expand Down
3 changes: 3 additions & 0 deletions compiler/src/dotty/tools/dotc/transform/init/Util.scala
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ object Util:
opaque type Arg = Tree | ByNameArg
case class ByNameArg(tree: Tree)

object Arg:
def apply(tree: Tree): Arg = tree

extension (arg: Arg)
def isByName = arg.isInstanceOf[ByNameArg]
def tree: Tree = arg match
Expand Down
22 changes: 10 additions & 12 deletions tests/init-global/neg/global-cycle1.check
Original file line number Diff line number Diff line change
@@ -1,17 +1,15 @@
-- Error: tests/init-global/neg/global-cycle1.scala:1:7 ----------------------------------------------------------------
1 |object A { // error
|^
|Cyclic initialization: object A -> object B -> object A. Calling trace:
|-> object A { // error [ global-cycle1.scala:1 ]
| ^
|-> val a: Int = B.b [ global-cycle1.scala:2 ]
| ^
|-> object B { [ global-cycle1.scala:5 ]
| ^
|-> val b: Int = A.a // error [ global-cycle1.scala:6 ]
| ^
2 | val a: Int = B.b
3 |}
| ^
| Cyclic initialization: object A -> object B -> object A. Calling trace:
| -> object A { // error [ global-cycle1.scala:1 ]
| ^
| -> val a: Int = B.b [ global-cycle1.scala:2 ]
| ^
| -> object B { [ global-cycle1.scala:5 ]
| ^
| -> val b: Int = A.a // error [ global-cycle1.scala:6 ]
| ^
-- Error: tests/init-global/neg/global-cycle1.scala:6:17 ---------------------------------------------------------------
6 | val b: Int = A.a // error
| ^^^
Expand Down
2 changes: 1 addition & 1 deletion tests/init-global/neg/global-cycle6.scala
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
object A { // error
val n: Int = B.m
class Inner {
println(n)
println(n) // error
}
}

Expand Down
11 changes: 11 additions & 0 deletions tests/init-global/neg/patmat-unapplySeq.check
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
-- Error: tests/init-global/neg/patmat-unapplySeq.scala:8:32 -----------------------------------------------------------
8 | def apply(i: Int): Box = array(i) // error
| ^^^^^^^^
|Reading mutable state of object A during initialization of object B.
|Reading mutable state of other static objects is forbidden as it breaks initialization-time irrelevance. Calling trace:
|-> object B: [ patmat-unapplySeq.scala:15 ]
| ^
|-> case A(b) => [ patmat-unapplySeq.scala:17 ]
| ^^^^
|-> def apply(i: Int): Box = array(i) // error [ patmat-unapplySeq.scala:8 ]
| ^^^^^^^^
17 changes: 17 additions & 0 deletions tests/init-global/neg/patmat-unapplySeq.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
object A:
class Box(var x: Int)

val array: Array[Box] = new Array(1)
array(0) = new Box(10)

def length: Int = array.length
def apply(i: Int): Box = array(i) // error
def drop(n: Int): Seq[Box] = array.toSeq
def toSeq: Seq[Box] = array.toSeq

def unapplySeq(array: Array[Box]): A.type = this


object B:
A.array match
case A(b) =>
17 changes: 17 additions & 0 deletions tests/init-global/neg/patmat-unapplySeq2.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
object A:
class Box(var x: Int)

val array: Array[Box] = new Array(1)
array(0) = new Box(10)

def length: Int = array.length
def apply(i: Int): Box = array(i) // error
def drop(n: Int): Seq[Box] = array.toSeq
def toSeq: Seq[Box] = array.toSeq

def unapplySeq(array: Array[Box]): A.type = this


object B:
A.array match
case A(b*) =>
Loading

0 comments on commit 04eae14

Please sign in to comment.