Skip to content

Commit

Permalink
WIP: Construct trees for HOAS patterns respecting eta expansion
Browse files Browse the repository at this point in the history
  • Loading branch information
zeptometer committed Jun 4, 2023
1 parent 209d763 commit c7a88fa
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 29 deletions.
19 changes: 18 additions & 1 deletion compiler/src/scala/quoted/runtime/impl/QuoteMatcher.scala
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import dotty.tools.dotc.core.Symbols.*
import dotty.tools.dotc.util.optional
import dotty.tools.dotc.core.Types
import dotty.tools.dotc.core.Definitions
import dotty.tools.dotc.ast.untpd

/** Matches a quoted tree against a quoted pattern tree.
* A quoted pattern tree may have type and term holes in addition to normal terms.
Expand Down Expand Up @@ -508,14 +509,30 @@ object QuoteMatcher {
case Block(List(DefDef(nme.ANON_FUN, _, _, Apply(Ident(name), _))), _) => name.asTermName
case arg => arg.symbol.name.asTermName
}
val paramTypes = args.map(x => mapTypeHoles(x.tpe.widenTermRefExpr))
val paramTypes = args.map(x => adaptTypes(mapTypeHoles(x.tpe.widenTermRefExpr)))
val methTpe = MethodType(names)(_ => paramTypes, _ => mapTypeHoles(patternTpe))
val meth = newAnonFun(ctx.owner, methTpe)
def bodyFn(lambdaArgss: List[List[Tree]]): Tree = {
val argsMap = args.view.map(_.symbol).zip(lambdaArgss.head).toMap
val body = new TreeMap {
override def transform(tree: Tree)(using Context): Tree =
tree match
/*
* When matching a method call `f(0)` against a HOAS pattern `p(g)` where
* f has a method type `(x: Int): Int` and `f` maps to `g`, `p` should be `g.apply(0)`
* because the type of `g` is `Int => Int` due to eta expansion.
*
* Remaining TODOs from issue-17105
* * [ ] cover the case of nested method call
* * [ ] contextual params?
* * [ ] erasure types?
*/
case Apply(methId: Ident, args) =>
val fnId = env.get(tree.symbol).flatMap(argsMap.get).getOrElse(tree)
ctx.typer.typed(
untpd.Apply(
untpd.Select(untpd.TypedSplice(fnId), nme.apply),
args.map(untpd.TypedSplice(_))))
case tree: Ident => env.get(tree.symbol).flatMap(argsMap.get).getOrElse(tree)
case tree => super.transform(tree)
}.transform(tree)
Expand Down
55 changes: 31 additions & 24 deletions tests/run-macros/i17105/Macro_1.scala
Original file line number Diff line number Diff line change
@@ -1,30 +1,37 @@
import scala.quoted.*

inline def test1: String = ${ testExpr1 }
def testExpr1(using Quotes): Expr[String] =
'{ def f(x: Int) = 1; val n = 2; f(n) } match
case '{ def g(y: Int) = 1; val n = 2; $a(g, n): Int } => Expr(a.show)
inline def testExpr(inline body: Any) = ${ testExprImpl('body) }
def testExprImpl(body: Expr[Any])(using Quotes): Expr[String] =
body match
case '{ def g(y: Int) = "hello" * y; $a(g): String } =>
'{ $a((z:Int) => "this is " + z.toString()) }
case _ => Expr("not matched")

inline def test2: String = ${ testExpr2 }
def testExpr2(using Quotes): Expr[String] =
'{ def f(x: Int, y:Int) = 1; f(1, 2) } match
case '{ def g(y: Int, z:Int) = 1; $a(g): Int } => Expr(a.show)
case _ => Expr("not matched")
// TODO issue-17105: Clean this up if not neccessary
// inline def test1: String = ${ testExpr1 }
// def testExpr1(using Quotes): Expr[String] =
// '{ def f(x: Int) = 1; val n = 2; f(n) } match
// case '{ def g(y: Int) = 1; val n = 2; $a(g, n): Int } => Expr(a.show)
// case _ => Expr("not matched")

inline def test3: String = ${ testExpr3 }
def testExpr3(using Quotes): Expr[String] =
'{
def f1(using Ordered[Int]) =
def f2(using Ordered[Int]) =
1 < 2
f2 || 2 < 3: Boolean
} match
case '{
def g1(using ord: Ordered[Int]) =
def g2(using Ordered[Int]) =
1 < 2
$a(g2, ord): Boolean
} => Expr(a.show)
case _ => Expr("not matched")
// inline def test2: String = ${ testExpr2 }
// def testExpr2(using Quotes): Expr[String] =
// '{ def f(x: Int, y:Int) = 1; f(1, 2) } match
// case '{ def g(y: Int, z:Int) = 1; $a(g): Int } => Expr(a.show)
// case _ => Expr("not matched")

// inline def test3: String = ${ testExpr3 }
// def testExpr3(using Quotes): Expr[String] =
// '{
// def f1(using Ordered[Int]) =
// def f2(using Ordered[Int]) =
// 1 < 2
// f2 || 2 < 3: Boolean
// } match
// case '{
// def g1(using ord: Ordered[Int]) =
// def g2(using Ordered[Int]) =
// 1 < 2
// $a(g2, ord): Boolean
// } => Expr(a.show)
// case _ => Expr("not matched")
6 changes: 2 additions & 4 deletions tests/run-macros/i17105/Test_2.scala
Original file line number Diff line number Diff line change
@@ -1,4 +1,2 @@
@main def Test: Unit =
println(test1)
println(test2)
println(test3)
@main def app: Unit =
testExpr { def f(x: Int) = "hello" * x; f(0) + "bye" }

0 comments on commit c7a88fa

Please sign in to comment.