Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make HOAS Quote pattern match with def method capture #17567

Merged
merged 1 commit into from
Jun 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
97 changes: 69 additions & 28 deletions compiler/src/scala/quoted/runtime/impl/QuoteMatcher.scala
Original file line number Diff line number Diff line change
Expand Up @@ -259,12 +259,34 @@ object QuoteMatcher {
// Matches an open term and wraps it into a lambda that provides the free variables
case Apply(TypeApply(Ident(_), List(TypeTree())), SeqLiteral(args, _) :: Nil)
if pattern.symbol.eq(defn.QuotedRuntimePatterns_higherOrderHole) =>

/* Some of method symbols in arguments of higher-order term hole are eta-expanded.
* e.g.
* g: (Int) => Int
* => {
* def $anonfun(y: Int): Int = g(y)
* closure($anonfun)
* }
*
* f: (using Int) => Int
* => f(using x)
* This function restores the symbol of the original method from
* the eta-expanded function.
*/
def getCapturedIdent(arg: Tree)(using Context): Ident =
arg match
case id: Ident => id
case Apply(fun, _) => getCapturedIdent(fun)
case Block((ddef: DefDef) :: _, _: Closure) => getCapturedIdent(ddef.rhs)
case Typed(expr, _) => getCapturedIdent(expr)

val env = summon[Env]
val capturedArgs = args.map(_.symbol)
val captureEnv = env.filter((k, v) => !capturedArgs.contains(v))
val capturedIds = args.map(getCapturedIdent)
val capturedSymbols = capturedIds.map(_.symbol)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

getCapturedIdent could return the Symbol directly. This way, we can avoid this extra map.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here I intended to use capturedIds as a parameter to matchedOpen and we cannot omit it (we get compiler errors if we use args for matchedOpen instead).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

According to your other comment on i17105.check, it's likely I need to change this part anyway.

val captureEnv = env.filter((k, v) => !capturedSymbols.contains(v))
withEnv(captureEnv) {
scrutinee match
case ClosedPatternTerm(scrutinee) => matchedOpen(scrutinee, pattern.tpe, args, env)
case ClosedPatternTerm(scrutinee) => matchedOpen(scrutinee, pattern.tpe, capturedIds, args.map(_.tpe), env)
case _ => notMatched
}

Expand Down Expand Up @@ -394,19 +416,34 @@ object QuoteMatcher {
case scrutinee @ DefDef(_, paramss1, tpt1, _) =>
pattern match
case pattern @ DefDef(_, paramss2, tpt2, _) =>
def rhsEnv: Env =
val paramSyms: List[(Symbol, Symbol)] =
for
(clause1, clause2) <- paramss1.zip(paramss2)
(param1, param2) <- clause1.zip(clause2)
yield
param1.symbol -> param2.symbol
val oldEnv: Env = summon[Env]
val newEnv: List[(Symbol, Symbol)] = (scrutinee.symbol -> pattern.symbol) :: paramSyms
oldEnv ++ newEnv
matchLists(paramss1, paramss2)(_ =?= _)
&&& tpt1 =?= tpt2
&&& withEnv(rhsEnv)(scrutinee.rhs =?= pattern.rhs)
def matchErasedParams(sctype: Type, pttype: Type): optional[MatchingExprs] =
(sctype, pttype) match
case (sctpe: MethodType, pttpe: MethodType) =>
if sctpe.erasedParams.sameElements(pttpe.erasedParams) then
matchErasedParams(sctpe.resType, pttpe.resType)
else
notMatched
case _ => matched

def matchParamss(scparamss: List[ParamClause], ptparamss: List[ParamClause])(using Env): optional[(Env, MatchingExprs)] =
zeptometer marked this conversation as resolved.
Show resolved Hide resolved
(scparamss, ptparamss) match {
case (scparams :: screst, ptparams :: ptrest) =>
val mr1 = matchLists(scparams, ptparams)(_ =?= _)
val newEnv = summon[Env] ++ scparams.map(_.symbol).zip(ptparams.map(_.symbol))
val (resEnv, mrrest) = withEnv(newEnv)(matchParamss(screst, ptrest))
(resEnv, mr1 &&& mrrest)
case (Nil, Nil) => (summon[Env], matched)
case _ => notMatched
}

val ematch = matchErasedParams(scrutinee.tpe.widenTermRefExpr, pattern.tpe.widenTermRefExpr)
val (pEnv, pmatch) = matchParamss(paramss1, paramss2)
val defEnv = pEnv + (scrutinee.symbol -> pattern.symbol)

ematch
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ematch is always Seq.empty and pmatch should also return Seq.empty on match.
Should I remove the part ematch &&& pmatch &&& to simplify logic?

&&& pmatch
&&& withEnv(defEnv)(tpt1 =?= tpt2)
&&& withEnv(defEnv)(scrutinee.rhs =?= pattern.rhs)
case _ => notMatched

case Closure(_, _, tpt1) =>
Expand Down Expand Up @@ -497,10 +534,11 @@ object QuoteMatcher {
*
* @param tree Scrutinee sub-tree that matched
* @param patternTpe Type of the pattern hole (from the pattern)
* @param args HOAS arguments (from the pattern)
* @param argIds Identifiers of HOAS arguments (from the pattern)
* @param argTypes Eta-expanded types of HOAS arguments (from the pattern)
* @param env Mapping between scrutinee and pattern variables
*/
case OpenTree(tree: Tree, patternTpe: Type, args: List[Tree], env: Env)
case OpenTree(tree: Tree, patternTpe: Type, argIds: List[Tree], argTypes: List[Type], env: Env)

/** Return the expression that was extracted from a hole.
*
Expand All @@ -513,19 +551,22 @@ object QuoteMatcher {
def toExpr(mapTypeHoles: Type => Type, spliceScope: Scope)(using Context): Expr[Any] = this match
case MatchResult.ClosedTree(tree) =>
new ExprImpl(tree, spliceScope)
case MatchResult.OpenTree(tree, patternTpe, args, env) =>
val names: List[TermName] = args.map {
case Block(List(DefDef(nme.ANON_FUN, _, _, Apply(Ident(name), _))), _) => name.asTermName
case arg => arg.symbol.name.asTermName
}
val paramTypes = args.map(x => mapTypeHoles(x.tpe.widenTermRefExpr))
case MatchResult.OpenTree(tree, patternTpe, argIds, argTypes, env) =>
val names: List[TermName] = argIds.map(_.symbol.name.asTermName)
val paramTypes = argTypes.map(tpe => mapTypeHoles(tpe.widenTermRefExpr))
val methTpe = MethodType(names)(_ => paramTypes, _ => mapTypeHoles(patternTpe))
val meth = newAnonFun(ctx.owner, methTpe)
def bodyFn(lambdaArgss: List[List[Tree]]): Tree = {
val argsMap = args.view.map(_.symbol).zip(lambdaArgss.head).toMap
val argsMap = argIds.view.map(_.symbol).zip(lambdaArgss.head).toMap
val body = new TreeMap {
override def transform(tree: Tree)(using Context): Tree =
tree match
/*
* When matching a method call `f(0)` against a HOAS pattern `p(g)` where
* f has a method type `(x: Int): Int` and `f` maps to `g`, `p` should hold
* `g.apply(0)` because the type of `g` is `Int => Int` due to eta expansion.
*/
case Apply(fun, args) if env.contains(tree.symbol) => transform(fun).select(nme.apply).appliedToArgs(args)
case tree: Ident => env.get(tree.symbol).flatMap(argsMap.get).getOrElse(tree)
case tree => super.transform(tree)
}.transform(tree)
Expand All @@ -534,7 +575,7 @@ object QuoteMatcher {
val hoasClosure = Closure(meth, bodyFn)
new ExprImpl(hoasClosure, spliceScope)

private inline def notMatched: optional[MatchingExprs] =
private inline def notMatched[T]: optional[T] =
optional.break()

private inline def matched: MatchingExprs =
Expand All @@ -543,8 +584,8 @@ object QuoteMatcher {
private inline def matched(tree: Tree)(using Context): MatchingExprs =
Seq(MatchResult.ClosedTree(tree))

private def matchedOpen(tree: Tree, patternTpe: Type, args: List[Tree], env: Env)(using Context): MatchingExprs =
Seq(MatchResult.OpenTree(tree, patternTpe, args, env))
private def matchedOpen(tree: Tree, patternTpe: Type, argIds: List[Tree], argTypes: List[Type], env: Env)(using Context): MatchingExprs =
Seq(MatchResult.OpenTree(tree, patternTpe, argIds, argTypes, env))

extension (self: MatchingExprs)
/** Concatenates the contents of two successful matchings */
Expand Down
3 changes: 3 additions & 0 deletions tests/run-custom-args/run-macros-erased/i17105.check
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
case erased: [erased case]
case erased nested: c
case erased nested 2: d
25 changes: 25 additions & 0 deletions tests/run-custom-args/run-macros-erased/i17105/Macro_1.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import scala.quoted.*

inline def testExpr(inline body: Any) = ${ testExprImpl('body) }
def testExprImpl(body: Expr[Any])(using Quotes): Expr[String] =
body match
// Erased Types
case '{ def erasedfn(y: String) = "placeholder"; $a(erasedfn): String } =>
Expr("This case should not match")
case '{ def erasedfn(erased y: String) = "placeholder"; $a(erasedfn): String } =>
'{ $a((erased z: String) => "[erased case]") }
case '{
def erasedfn(a: String, b: String)(c: String, d: String): String = a
$y(erasedfn): String
} => Expr("This should not match")
case '{
def erasedfn(a: String, erased b: String)(erased c: String, d: String): String = a
$y(erasedfn): String
} =>
'{ $y((a: String, erased b: String) => (erased c: String, d: String) => d) }
case '{
def erasedfn(a: String, erased b: String)(c: String, erased d: String): String = a
$y(erasedfn): String
} =>
'{ $y((a: String, erased b: String) => (c: String, erased d: String) => c) }
case _ => Expr("not matched")
10 changes: 10 additions & 0 deletions tests/run-custom-args/run-macros-erased/i17105/Test_2.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
@main def Test: Unit =
println("case erased: " + testExpr { def erasedfn1(erased x: String) = "placeholder"; erasedfn1("arg1")})
println("case erased nested: " + testExpr {
def erasedfn2(p: String, erased q: String)(r: String, erased s: String) = p
erasedfn2("a", "b")("c", "d")
})
println("case erased nested 2: " + testExpr {
def erasedfn2(p: String, erased q: String)(erased r: String, s: String) = p
erasedfn2("a", "b")("c", "d")
})
8 changes: 8 additions & 0 deletions tests/run-macros/i17105.check
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
case single: [1st case] arg1 outside
case no-param-method (will be eta-expanded): [1st case] placeholder 2
case curried: [2nd case] arg1, arg2 outside
case methods from outer scope: [1st case] arg1 outer-method
case refinement: Hoe got 1
case dependent: 1
case dependent2: 1
case dependent3: 1
15 changes: 15 additions & 0 deletions tests/run-macros/i17105/Lib1.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@

// Test case for dependent types
trait DSL {
type N
def toString(n: N): String
val zero: N
def next(n: N): N
}

object IntDSL extends DSL {
type N = Int
def toString(n: N): String = n.toString()
val zero = 0
def next(n: N): N = n + 1
}
34 changes: 34 additions & 0 deletions tests/run-macros/i17105/Macro_2.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import scala.quoted.*
import language.experimental.erasedDefinitions

inline def testExpr(inline body: Any) = ${ testExprImpl('body) }
def testExprImpl(body: Expr[Any])(using Quotes): Expr[String] =
body match
case '{ def g(y: String) = "placeholder" + y; $a(g): String } =>
'{ $a((z: String) => s"[1st case] ${z}") }
case '{ def g(y: String)(z: String) = "placeholder" + y; $a(g): String } =>
'{ $a((z1: String) => (z2: String) => s"[2nd case] ${z1}, ${z2}") }
// Refined Types
case '{
type t
def refined(a: `t`): String = $x(a): String
$y(refined): String
} =>
'{ $y($x) }
// Dependent Types
case '{
def p(dsl: DSL): dsl.N = dsl.zero
$y(p): String
} =>
'{ $y((dsl1: DSL) => dsl1.next(dsl1.zero)) }
case '{
def p(dsl: DSL)(a: dsl.N): dsl.N = a
$y(p): String
} =>
'{ $y((dsl: DSL) => (b2: dsl.N) => dsl.next(b2)) }
case '{
def p(dsl1: DSL)(dsl2: DSL): dsl2.N = dsl2.zero
$y(p): String
} =>
'{ $y((dsl1: DSL) => (dsl2: DSL) => dsl2.next(dsl2.zero)) }
case _ => Expr("not matched")
23 changes: 23 additions & 0 deletions tests/run-macros/i17105/Test_3.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import reflect.Selectable.reflectiveSelectable

class Hoe { def f(x: Int): String = s"Hoe got ${x}" }

@main def Test: Unit =
println("case single: " + testExpr { def f(x: String) = "placeholder" + x; f("arg1") + " outside" })
println("case no-param-method (will be eta-expanded): " + testExpr { def f(x: String) = "placeholder" + x; (() => f)()("placeholder 2") })
println("case curried: " + testExpr { def f(x: String)(y: String) = "placeholder" + x; f("arg1")("arg2") + " outside" })
def outer() = " outer-method"
println("case methods from outer scope: " + testExpr { def f(x: String) = "placeholder" + x; f("arg1") + outer() })
println("case refinement: " + testExpr { def refined(a: { def f(x: Int): String }): String = a.f(1); refined(Hoe()) })
println("case dependent: " + testExpr {
def p(a: DSL): a.N = a.zero
IntDSL.toString(p(IntDSL))
})
println("case dependent2: " + testExpr {
def p(dsl1: DSL)(c: dsl1.N): dsl1.N = c
IntDSL.toString(p(IntDSL)(IntDSL.zero))
})
println("case dependent3: " + testExpr {
def p(dsl1: DSL)(dsl2: DSL): dsl2.N = dsl2.zero
IntDSL.toString(p(IntDSL)(IntDSL))
})