Skip to content

Commit

Permalink
Make HOAS Quote pattern match with def method capture
Browse files Browse the repository at this point in the history
closes #17105
  • Loading branch information
zeptometer committed Jun 19, 2023
1 parent a68568c commit 3af515d
Show file tree
Hide file tree
Showing 8 changed files with 191 additions and 28 deletions.
101 changes: 73 additions & 28 deletions compiler/src/scala/quoted/runtime/impl/QuoteMatcher.scala
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import dotty.tools.dotc.core.Types.*
import dotty.tools.dotc.core.StdNames.nme
import dotty.tools.dotc.core.Symbols.*
import dotty.tools.dotc.util.optional
import dotty.tools.dotc.core.Definitions

/** Matches a quoted tree against a quoted pattern tree.
* A quoted pattern tree may have type and term holes in addition to normal terms.
Expand Down Expand Up @@ -259,12 +260,34 @@ object QuoteMatcher {
// Matches an open term and wraps it into a lambda that provides the free variables
case Apply(TypeApply(Ident(_), List(TypeTree())), SeqLiteral(args, _) :: Nil)
if pattern.symbol.eq(defn.QuotedRuntimePatterns_higherOrderHole) =>

/* Some of method symbols in arguments of higher-order term hole are eta-expanded.
* e.g.
* g: (Int) => Int
* => {
* def $anonfun(y: Int): Int = g(y)
* closure($anonfun)
* }
*
* f: (using Int) => Int
* => f(using x)
* This function restores the symbol of the original method from
* the eta-expanded function.
*/
def getCapturedIdent(arg: Tree)(using Context): Ident =
arg match
case id: Ident => id
case Apply(fun, _) => getCapturedIdent(fun)
case Block((ddef: DefDef) :: _, _: Closure) => getCapturedIdent(ddef.rhs)
case Typed(expr, _) => getCapturedIdent(expr)

val env = summon[Env]
val capturedArgs = args.map(_.symbol)
val captureEnv = env.filter((k, v) => !capturedArgs.contains(v))
val capturedIds = args.map(getCapturedIdent)
val capturedSymbols = capturedIds.map(_.symbol)
val captureEnv = env.filter((k, v) => !capturedSymbols.contains(v))
withEnv(captureEnv) {
scrutinee match
case ClosedPatternTerm(scrutinee) => matchedOpen(scrutinee, pattern.tpe, args, env)
case ClosedPatternTerm(scrutinee) => matchedOpen(scrutinee, pattern.tpe, capturedIds, args.map(_.tpe), env)
case _ => notMatched
}

Expand Down Expand Up @@ -394,19 +417,34 @@ object QuoteMatcher {
case scrutinee @ DefDef(_, paramss1, tpt1, _) =>
pattern match
case pattern @ DefDef(_, paramss2, tpt2, _) =>
def rhsEnv: Env =
val paramSyms: List[(Symbol, Symbol)] =
for
(clause1, clause2) <- paramss1.zip(paramss2)
(param1, param2) <- clause1.zip(clause2)
yield
param1.symbol -> param2.symbol
val oldEnv: Env = summon[Env]
val newEnv: List[(Symbol, Symbol)] = (scrutinee.symbol -> pattern.symbol) :: paramSyms
oldEnv ++ newEnv
matchLists(paramss1, paramss2)(_ =?= _)
&&& tpt1 =?= tpt2
&&& withEnv(rhsEnv)(scrutinee.rhs =?= pattern.rhs)
def matchErasedParams(sctype: Type, pttype: Type): optional[MatchingExprs] =
(sctype, pttype) match
case (sctpe: MethodType, pttpe: MethodType) =>
if sctpe.erasedParams.sameElements(pttpe.erasedParams) then
matchErasedParams(sctpe.resType, pttpe.resType)
else
notMatched
case _ => matched

def matchParamss(scparamss: List[ParamClause], ptparamss: List[ParamClause])(using Env): optional[(Env, MatchingExprs)] =
(scparamss, ptparamss) match {
case (scparams :: screst, ptparams :: ptrest) =>
val mr1 = matchLists(scparams, ptparams)(_ =?= _)
val newEnv = summon[Env] ++ scparams.map(_.symbol).zip(ptparams.map(_.symbol))
val (resEnv, mrrest) = withEnv(newEnv)(matchParamss(screst, ptrest))
(resEnv, mr1 &&& mrrest)
case (Nil, Nil) => (summon[Env], matched)
case _ => notMatched
}

val ematch = matchErasedParams(scrutinee.tpe.widenTermRefExpr, pattern.tpe.widenTermRefExpr)
val (pEnv, pmatch) = matchParamss(paramss1, paramss2)
val defEnv = pEnv + (scrutinee.symbol -> pattern.symbol)

ematch
&&& pmatch
&&& withEnv(defEnv)(tpt1 =?= tpt2)
&&& withEnv(defEnv)(scrutinee.rhs =?= pattern.rhs)
case _ => notMatched

case Closure(_, _, tpt1) =>
Expand Down Expand Up @@ -497,10 +535,14 @@ object QuoteMatcher {
*
* @param tree Scrutinee sub-tree that matched
* @param patternTpe Type of the pattern hole (from the pattern)
* @param args HOAS arguments (from the pattern)
* @param argIds Identifiers of HOAS arguments (from the pattern)
* @param argTypes Eta-expanded types of HOAS arguments (from the pattern)
* @param env Mapping between scrutinee and pattern variables
*/
case OpenTree(tree: Tree, patternTpe: Type, args: List[Tree], env: Env)
case OpenTree(tree: Tree, patternTpe: Type, argIds: List[Tree], argTypes: List[Type], env: Env)

/** The Definitions object */
def defn(using Context): Definitions = ctx.definitions

/** Return the expression that was extracted from a hole.
*
Expand All @@ -513,19 +555,22 @@ object QuoteMatcher {
def toExpr(mapTypeHoles: Type => Type, spliceScope: Scope)(using Context): Expr[Any] = this match
case MatchResult.ClosedTree(tree) =>
new ExprImpl(tree, spliceScope)
case MatchResult.OpenTree(tree, patternTpe, args, env) =>
val names: List[TermName] = args.map {
case Block(List(DefDef(nme.ANON_FUN, _, _, Apply(Ident(name), _))), _) => name.asTermName
case arg => arg.symbol.name.asTermName
}
val paramTypes = args.map(x => mapTypeHoles(x.tpe.widenTermRefExpr))
case MatchResult.OpenTree(tree, patternTpe, argIds, argTypes, env) =>
val names: List[TermName] = argIds.map(_.symbol.name.asTermName)
val paramTypes = argTypes.map(tpe => mapTypeHoles(tpe.widenTermRefExpr))
val methTpe = MethodType(names)(_ => paramTypes, _ => mapTypeHoles(patternTpe))
val meth = newAnonFun(ctx.owner, methTpe)
def bodyFn(lambdaArgss: List[List[Tree]]): Tree = {
val argsMap = args.view.map(_.symbol).zip(lambdaArgss.head).toMap
val argsMap = argIds.view.map(_.symbol).zip(lambdaArgss.head).toMap
val body = new TreeMap {
override def transform(tree: Tree)(using Context): Tree =
tree match
/*
* When matching a method call `f(0)` against a HOAS pattern `p(g)` where
* f has a method type `(x: Int): Int` and `f` maps to `g`, `p` should hold
* `g.apply(0)` because the type of `g` is `Int => Int` due to eta expansion.
*/
case Apply(fun, args) if env.contains(tree.symbol) => transform(fun).select(nme.apply).appliedToArgs(args)
case tree: Ident => env.get(tree.symbol).flatMap(argsMap.get).getOrElse(tree)
case tree => super.transform(tree)
}.transform(tree)
Expand All @@ -534,7 +579,7 @@ object QuoteMatcher {
val hoasClosure = Closure(meth, bodyFn)
new ExprImpl(hoasClosure, spliceScope)

private inline def notMatched: optional[MatchingExprs] =
private inline def notMatched[T]: optional[T] =
optional.break()

private inline def matched: MatchingExprs =
Expand All @@ -543,8 +588,8 @@ object QuoteMatcher {
private inline def matched(tree: Tree)(using Context): MatchingExprs =
Seq(MatchResult.ClosedTree(tree))

private def matchedOpen(tree: Tree, patternTpe: Type, args: List[Tree], env: Env)(using Context): MatchingExprs =
Seq(MatchResult.OpenTree(tree, patternTpe, args, env))
private def matchedOpen(tree: Tree, patternTpe: Type, argIds: List[Tree], argTypes: List[Type], env: Env)(using Context): MatchingExprs =
Seq(MatchResult.OpenTree(tree, patternTpe, argIds, argTypes, env))

extension (self: MatchingExprs)
/** Concatenates the contents of two successful matchings */
Expand Down
3 changes: 3 additions & 0 deletions tests/run-custom-args/run-macros-erased/i17105.check
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
case erased: [erased case]
case erased nested: c
case erased nested 2: d
25 changes: 25 additions & 0 deletions tests/run-custom-args/run-macros-erased/i17105/Macro_1.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import scala.quoted.*

inline def testExpr(inline body: Any) = ${ testExprImpl('body) }
def testExprImpl(body: Expr[Any])(using Quotes): Expr[String] =
body match
// Erased Types
case '{ def erasedfn(y: String) = "placeholder"; $a(erasedfn): String } =>
Expr("This case should not match")
case '{ def erasedfn(erased y: String) = "placeholder"; $a(erasedfn): String } =>
'{ $a((erased z: String) => "[erased case]") }
case '{
def erasedfn(a: String, b: String)(c: String, d: String): String = a
$y(erasedfn): String
} => Expr("This should not match")
case '{
def erasedfn(a: String, erased b: String)(erased c: String, d: String): String = a
$y(erasedfn): String
} =>
'{ $y((a: String, erased b: String) => (erased c: String, d: String) => d) }
case '{
def erasedfn(a: String, erased b: String)(c: String, erased d: String): String = a
$y(erasedfn): String
} =>
'{ $y((a: String, erased b: String) => (c: String, erased d: String) => c) }
case _ => Expr("not matched")
10 changes: 10 additions & 0 deletions tests/run-custom-args/run-macros-erased/i17105/Test_2.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
@main def Test: Unit =
println("case erased: " + testExpr { def erasedfn1(erased x: String) = "placeholder"; erasedfn1("arg1")})
println("case erased nested: " + testExpr {
def erasedfn2(p: String, erased q: String)(r: String, erased s: String) = p
erasedfn2("a", "b")("c", "d")
})
println("case erased nested 2: " + testExpr {
def erasedfn2(p: String, erased q: String)(erased r: String, s: String) = p
erasedfn2("a", "b")("c", "d")
})
8 changes: 8 additions & 0 deletions tests/run-macros/i17105.check
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
case single: [1st case] arg1 outside
case no-param-method (will be eta-expanded): [1st case] placeholder 2
case curried: [2nd case] arg1, arg2 outside
case methods from outer scope: [1st case] arg1 outer-method
case refinement: Hoe got 1
case dependent: 1
case dependent2: 1
case dependent3: 1
15 changes: 15 additions & 0 deletions tests/run-macros/i17105/Lib1.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@

// Test case for dependent types
trait DSL {
type N
def toString(n: N): String
val zero: N
def next(n: N): N
}

object IntDSL extends DSL {
type N = Int
def toString(n: N): String = n.toString()
val zero = 0
def next(n: N): N = n + 1
}
34 changes: 34 additions & 0 deletions tests/run-macros/i17105/Macro_2.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import scala.quoted.*
import language.experimental.erasedDefinitions

inline def testExpr(inline body: Any) = ${ testExprImpl('body) }
def testExprImpl(body: Expr[Any])(using Quotes): Expr[String] =
body match
case '{ def g(y: String) = "placeholder" + y; $a(g): String } =>
'{ $a((z: String) => s"[1st case] ${z}") }
case '{ def g(y: String)(z: String) = "placeholder" + y; $a(g): String } =>
'{ $a((z1: String) => (z2: String) => s"[2nd case] ${z1}, ${z2}") }
// Refined Types
case '{
type t
def refined(a: `t`): String = $x(a): String
$y(refined): String
} =>
'{ $y($x) }
// Dependent Types
case '{
def p(dsl: DSL): dsl.N = dsl.zero
$y(p): String
} =>
'{ $y((dsl1: DSL) => dsl1.next(dsl1.zero)) }
case '{
def p(dsl: DSL)(a: dsl.N): dsl.N = a
$y(p): String
} =>
'{ $y((dsl: DSL) => (b2: dsl.N) => dsl.next(b2)) }
case '{
def p(dsl1: DSL)(dsl2: DSL): dsl2.N = dsl2.zero
$y(p): String
} =>
'{ $y((dsl1: DSL) => (dsl2: DSL) => dsl2.next(dsl2.zero)) }
case _ => Expr("not matched")
23 changes: 23 additions & 0 deletions tests/run-macros/i17105/Test_3.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import reflect.Selectable.reflectiveSelectable

class Hoe { def f(x: Int): String = s"Hoe got ${x}" }

@main def Test: Unit =
println("case single: " + testExpr { def f(x: String) = "placeholder" + x; f("arg1") + " outside" })
println("case no-param-method (will be eta-expanded): " + testExpr { def f(x: String) = "placeholder" + x; (() => f)()("placeholder 2") })
println("case curried: " + testExpr { def f(x: String)(y: String) = "placeholder" + x; f("arg1")("arg2") + " outside" })
def outer() = " outer-method"
println("case methods from outer scope: " + testExpr { def f(x: String) = "placeholder" + x; f("arg1") + outer() })
println("case refinement: " + testExpr { def refined(a: { def f(x: Int): String }): String = a.f(1); refined(Hoe()) })
println("case dependent: " + testExpr {
def p(a: DSL): a.N = a.zero
IntDSL.toString(p(IntDSL))
})
println("case dependent2: " + testExpr {
def p(dsl1: DSL)(c: dsl1.N): dsl1.N = c
IntDSL.toString(p(IntDSL)(IntDSL.zero))
})
println("case dependent3: " + testExpr {
def p(dsl1: DSL)(dsl2: DSL): dsl2.N = dsl2.zero
IntDSL.toString(p(IntDSL)(IntDSL))
})

0 comments on commit 3af515d

Please sign in to comment.