Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pattern match support in checking global objects #18127

Merged
merged 9 commits into from
Jul 14, 2023
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
201 changes: 189 additions & 12 deletions compiler/src/dotty/tools/dotc/transform/init/Objects.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,18 @@ import core.*
import Contexts.*
import Symbols.*
import Types.*
import Denotations.Denotation
import StdNames.*
import Names.TermName
import NameKinds.OuterSelectName
import NameKinds.SuperAccessorName

import ast.tpd.*
import util.SourcePosition
import util.{ SourcePosition, NoSourcePosition }
import config.Printers.init as printer
import reporting.StoreReporter
import reporting.trace as log
import typer.Applications.*

import Errors.*
import Trace.*
Expand Down Expand Up @@ -834,11 +837,10 @@ object Objects:

/** Handle local variable definition, `val x = e` or `var x = e`.
*
* @param ref The value for `this` where the variable is defined.
* @param sym The symbol of the variable.
* @param value The value of the initializer.
*/
def initLocal(ref: Ref, sym: Symbol, value: Value): Contextual[Unit] = log("initialize local " + sym.show + " with " + value.show, printer) {
def initLocal(sym: Symbol, value: Value): Contextual[Unit] = log("initialize local " + sym.show + " with " + value.show, printer) {
if sym.is(Flags.Mutable) then
val addr = Heap.localVarAddr(summon[Regions.Data], sym, State.currentObject)
Env.setLocalVar(sym, addr)
Expand Down Expand Up @@ -870,9 +872,6 @@ object Objects:
case _ =>
report.warning("[Internal error] Variable not found " + sym.show + "\nenv = " + env.show + ". Calling trace:\n" + Trace.show, Trace.position)
Bottom
else if sym.isPatternBound then
// TODO: handle patterns
Cold
else
given Env.Data = env
// Assume forward reference check is doing a good job
Expand Down Expand Up @@ -1113,11 +1112,9 @@ object Objects:
else
eval(arg, thisV, klass)

case Match(selector, cases) =>
eval(selector, thisV, klass)
// TODO: handle pattern match properly
report.warning("[initChecker] Pattern match is skipped. Trace:\n" + Trace.show, expr)
Bottom
case Match(scrutinee, cases) =>
val scrutineeValue = eval(scrutinee, thisV, klass)
patternMatch(scrutineeValue, cases, thisV, klass)

case Return(expr, from) =>
Returns.handle(from.symbol, eval(expr, thisV, klass))
Expand Down Expand Up @@ -1151,7 +1148,7 @@ object Objects:
// local val definition
val rhs = eval(vdef.rhs, thisV, klass)
val sym = vdef.symbol
initLocal(thisV.asInstanceOf[Ref], vdef.symbol, rhs)
initLocal(vdef.symbol, rhs)
Bottom

case ddef : DefDef =>
Expand All @@ -1173,6 +1170,186 @@ object Objects:
Bottom
}

/** Evaluate the cases against the scrutinee value.
*
* @param scrutinee The abstract value of the scrutinee.
* @param cases The cases to match.
* @param thisV The value for `C.this` where `C` is represented by `klass`.
* @param klass The enclosing class where the type `tp` is located.
*/
def patternMatch(scrutinee: Value, cases: List[CaseDef], thisV: Value, klass: ClassSymbol): Contextual[Value] =
// expected member types for `unapplySeq`
def lengthType = ExprType(defn.IntType)
def lengthCompareType = MethodType(List(defn.IntType), defn.IntType)
def applyType(elemTp: Type) = MethodType(List(defn.IntType), elemTp)
def dropType(elemTp: Type) = MethodType(List(defn.IntType), defn.CollectionSeqType.appliedTo(elemTp))
def toSeqType(elemTp: Type) = ExprType(defn.CollectionSeqType.appliedTo(elemTp))

def getMemberMethod(receiver: Type, name: TermName, tp: Type): Denotation =
receiver.member(name).suchThat(receiver.memberInfo(_) <:< tp)

def evalCase(caseDef: CaseDef): Value =
evalPattern(scrutinee, caseDef.pat)
eval(caseDef.guard, thisV, klass)
eval(caseDef.body, thisV, klass)

/** Abstract evaluation of patterns.
*
* It augments the local environment for bound pattern variables. As symbols are globally
* unique, we can put them in a single environment.
*
* Currently, we assume all cases are reachable, thus all patterns are assumed to match.
liufengyun marked this conversation as resolved.
Show resolved Hide resolved
*/
liufengyun marked this conversation as resolved.
Show resolved Hide resolved
def evalPattern(scrutinee: Value, pat: Tree): Value = log("match " + scrutinee.show + " against " + pat.show, printer, (_: Value).show):
val trace2 = Trace.trace.add(pat)
pat match
case Alternative(pats) =>
for pat <- pats do evalPattern(scrutinee, pat)
scrutinee

case bind @ Bind(_, pat) =>
val value = evalPattern(scrutinee, pat)
initLocal(bind.symbol, value)
scrutinee

case UnApply(fun, implicits, pats) =>
given Trace = trace2

val fun1 = funPart(fun)
val funRef = fun1.tpe.asInstanceOf[TermRef]
val unapplyResTp = funRef.widen.finalResultType

val receiver = evalType(funRef.prefix, thisV, klass)
val implicitValues = evalArgs(implicits.map(Arg.apply), thisV, klass)
// TODO: implicit values may appear before and/or after the scrutinee parameter.
val unapplyRes = call(receiver, funRef.symbol, TraceValue(scrutinee, summon[Trace]) :: implicitValues, funRef.prefix, superType = NoType, needResolve = true)

if fun.symbol.name == nme.unapplySeq then
var resultTp = unapplyResTp
var elemTp = unapplySeqTypeElemTp(resultTp)
var arity = productArity(resultTp, NoSourcePosition)
var needsGet = false
if (!elemTp.exists && arity <= 0) {
needsGet = true
resultTp = resultTp.select(nme.get).finalResultType
elemTp = unapplySeqTypeElemTp(resultTp.widen)
arity = productSelectorTypes(resultTp, NoSourcePosition).size
}

var resToMatch = unapplyRes

if needsGet then
// Get match
val isEmptyDenot = unapplyResTp.member(nme.isEmpty).suchThat(_.info.isParameterless)
call(unapplyRes, isEmptyDenot.symbol, Nil, unapplyResTp, superType = NoType, needResolve = true)

val getDenot = unapplyResTp.member(nme.get).suchThat(_.info.isParameterless)
resToMatch = call(unapplyRes, getDenot.symbol, Nil, unapplyResTp, superType = NoType, needResolve = true)
end if

if elemTp.exists then
// sequence match
evalSeqPatterns(resToMatch, resultTp, elemTp, pats)
else
// product sequence match
val selectors = productSelectors(resultTp)
assert(selectors.length <= pats.length)
selectors.init.zip(pats).map { (sel, pat) =>
val selectRes = call(resToMatch, sel, Nil, resultTp, superType = NoType, needResolve = true)
evalPattern(selectRes, pat)
}
val seqPats = pats.drop(selectors.length - 1)
val toSeqRes = call(resToMatch, selectors.last, Nil, resultTp, superType = NoType, needResolve = true)
val toSeqResTp = resultTp.memberInfo(selectors.last).finalResultType
evalSeqPatterns(toSeqRes, toSeqResTp, elemTp, seqPats)
end if

else
// distribute unapply to patterns
if isProductMatch(unapplyResTp, pats.length) then
// product match
val selectors = productSelectors(unapplyResTp)
assert(selectors.length == pats.length)
selectors.zip(pats).map { (sel, pat) =>
val selectRes = call(unapplyRes, sel, Nil, unapplyResTp, superType = NoType, needResolve = true)
evalPattern(selectRes, pat)
}
else if unapplyResTp <:< defn.BooleanType then
// Boolean extractor, do nothing
()
else
// Get match
val isEmptyDenot = unapplyResTp.member(nme.isEmpty).suchThat(_.info.isParameterless)
call(unapplyRes, isEmptyDenot.symbol, Nil, unapplyResTp, superType = NoType, needResolve = true)

val getDenot = unapplyResTp.member(nme.get).suchThat(_.info.isParameterless)
val getRes = call(unapplyRes, getDenot.symbol, Nil, unapplyResTp, superType = NoType, needResolve = true)
if pats.length == 1 then
// single match
evalPattern(getRes, pats.head)
else
val getResTp = getDenot.info.finalResultType
val selectors = productSelectors(getResTp).take(pats.length)
selectors.zip(pats).map { (sel, pat) =>
val selectRes = call(unapplyRes, sel, Nil, getResTp, superType = NoType, needResolve = true)
evalPattern(selectRes, pat)
}
end if
end if
end if
scrutinee

case Ident(nme.WILDCARD) | Ident(nme.WILDCARD_STAR) =>
scrutinee

case Typed(pat, _) =>
evalPattern(scrutinee, pat)

case tree =>
// For all other trees, the semantics is normal.
eval(tree, thisV, klass)

end evalPattern

/**
* Evaluate a sequence value against sequence patterns.
*/
def evalSeqPatterns(scrutinee: Value, scrutineeType: Type, elemType: Type, pats: List[Tree])(using Trace): Unit =
// call .lengthCompare or .length
val lengthCompareDenot = getMemberMethod(scrutineeType, nme.lengthCompare, lengthCompareType)
if lengthCompareDenot.exists then
call(scrutinee, lengthCompareDenot.symbol, TraceValue(Bottom, summon[Trace]) :: Nil, scrutineeType, superType = NoType, needResolve = true)
else
val lengthDenot = getMemberMethod(scrutineeType, nme.length, lengthType)
call(scrutinee, lengthDenot.symbol, Nil, scrutineeType, superType = NoType, needResolve = true)
end if

// call .apply
val applyDenot = getMemberMethod(scrutineeType, nme.apply, applyType(elemType))
val applyRes = call(scrutinee, applyDenot.symbol, TraceValue(Bottom, summon[Trace]) :: Nil, scrutineeType, superType = NoType, needResolve = true)

if isWildcardStarArg(pats.last) then
if pats.size == 1 then
// call .toSeq
val toSeqDenot = scrutineeType.member(nme.toSeq).suchThat(_.info.isParameterless)
val toSeqRes = call(scrutinee, toSeqDenot.symbol, Nil, scrutineeType, superType = NoType, needResolve = true)
evalPattern(toSeqRes, pats.head)
else
// call .drop
val dropDenot = getMemberMethod(scrutineeType, nme.drop, applyType(elemType))
val dropRes = call(scrutinee, dropDenot.symbol, TraceValue(Bottom, summon[Trace]) :: Nil, scrutineeType, superType = NoType, needResolve = true)
for pat <- pats.init do evalPattern(applyRes, pat)
evalPattern(dropRes, pats.last)
end if
else
// no patterns like `xs*`
for pat <- pats do evalPattern(applyRes, pat)
end evalSeqPatterns


cases.map(evalCase).join

liufengyun marked this conversation as resolved.
Show resolved Hide resolved

/** Handle semantics of leaf nodes
*
* For leaf nodes, their semantics is determined by their types.
Expand Down
3 changes: 3 additions & 0 deletions compiler/src/dotty/tools/dotc/transform/init/Util.scala
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ object Util:
opaque type Arg = Tree | ByNameArg
case class ByNameArg(tree: Tree)

object Arg:
def apply(tree: Tree): Arg = tree

extension (arg: Arg)
def isByName = arg.isInstanceOf[ByNameArg]
def tree: Tree = arg match
Expand Down
11 changes: 11 additions & 0 deletions tests/init-global/neg/patmat-unapplySeq.check
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
-- Error: tests/init-global/neg/patmat-unapplySeq.scala:8:32 -----------------------------------------------------------
8 | def apply(i: Int): Box = array(i) // error
| ^^^^^^^^
|Reading mutable state of object A during initialization of object B.
|Reading mutable state of other static objects is forbidden as it breaks initialization-time irrelevance. Calling trace:
|-> object B: [ patmat-unapplySeq.scala:15 ]
| ^
|-> case A(b) => [ patmat-unapplySeq.scala:17 ]
| ^^^^
|-> def apply(i: Int): Box = array(i) // error [ patmat-unapplySeq.scala:8 ]
| ^^^^^^^^
17 changes: 17 additions & 0 deletions tests/init-global/neg/patmat-unapplySeq.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
object A:
class Box(var x: Int)

val array: Array[Box] = new Array(1)
array(0) = new Box(10)

def length: Int = array.length
def apply(i: Int): Box = array(i) // error
def drop(n: Int): Seq[Box] = array.toSeq
def toSeq: Seq[Box] = array.toSeq

def unapplySeq(array: Array[Box]): A.type = this


object B:
A.array match
case A(b) =>
17 changes: 17 additions & 0 deletions tests/init-global/neg/patmat-unapplySeq2.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
object A:
class Box(var x: Int)

val array: Array[Box] = new Array(1)
array(0) = new Box(10)

def length: Int = array.length
def apply(i: Int): Box = array(i) // error
def drop(n: Int): Seq[Box] = array.toSeq
def toSeq: Seq[Box] = array.toSeq

def unapplySeq(array: Array[Box]): A.type = this


object B:
A.array match
case A(b*) =>
36 changes: 36 additions & 0 deletions tests/init-global/neg/patmat.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
object A: // error
val a: Option[Int] = Some(3)
a match
case Some(x) => println(x * 2 + B.a.size)
case None => println(0)

object B:
val a = 3 :: 4 :: Nil
a match
case x :: xs =>
println(x * 2)
if A.a.isEmpty then println(xs.size)
case Nil =>
println(0)

case class Box[T](value: T)
case class Holder[T](value: T)
object C:
(Box(5): Box[Int] | Holder[Int]) match
case Box(x) => x
case Holder(x) => x

(Box(5): Box[Int] | Holder[Int]) match
case box: Box[Int] => box.value
case holder: Holder[Int] => holder.value

val a: Int = Inner.b

object Inner: // error
val b: Int = 10

val foo: () => Int = () => C.a

(Box(foo): Box[() => Int] | Holder[Int]) match
case Box(f) => f()
case Holder(x) => x
14 changes: 14 additions & 0 deletions tests/init-global/pos/patmat.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
object A:
val a: Option[Int] = Some(3)
a match
case Some(x) => println(x * 2)
case None => println(0)

object B:
val a = 3 :: 4 :: Nil
a match
case x :: xs =>
println(x * 2)
println(xs.size)
case Nil =>
println(0)
Loading