Skip to content

Commit

Permalink
Redefine quoted.Expr.betaReduce
Browse files Browse the repository at this point in the history
Redefine `quoted.Expr.betaRduce` to not rely on complex type level computations.
Changed the signature as follows
```diff
- def betaReduce[F, Args <: Tuple, R, G](f: Expr[F])(using tf: TupledFunction[F, Args => R], tg: TupledFunction[G, TupleOfExpr[Args] => Expr[R]], qctx: QuoteContext): G = ...
+ def betaReduce[T](expr: Expr[T])(using qctx: QuoteContext): Option[Expr[T]] = ...
```

Improvements
* Simpler API that covers all kind of functions at once (normal/given/erased)
* Better error message for ill-typed `betaRduce` calls
* Adds the possiblility of knowing if the beta-reeduction suceeded
* Use `transform.BetaReduce`
* Improve `transform.BetaReduce` to handle `Inlined` trees and constant argumets
* Fixes #9466

Drawback
* Need for slightly loneger code (previous interface could be implented on top of this one)
  • Loading branch information
nicolasstucki committed Jul 30, 2020
1 parent 316d218 commit 34d2587
Show file tree
Hide file tree
Showing 43 changed files with 845 additions and 830 deletions.
Expand Up @@ -2042,21 +2042,18 @@ class ReflectionCompilerInterface(val rootContext: core.Contexts.Context) extend
case _ => None
}

def betaReduce(fn: Term, args: List[Term])(using Context): Term = {
val (argVals0, argRefs0) = args.foldLeft((List.empty[ValDef], List.empty[Tree])) { case ((acc1, acc2), arg) => arg.tpe match {
case tpe: SingletonType if isIdempotentExpr(arg) => (acc1, arg :: acc2)
def betaReduce(tree: Term)(using Context): Option[Term] =
tree match
case app @ Apply(Select(fn, nme.apply), args) if defn.isFunctionType(fn.tpe) =>
val app1 = transform.BetaReduce(app, fn, args)
if app1 eq app then None
else Some(app1.withSpan(tree.span))
case Block(Nil, expr) =>
for e <- betaReduce(expr) yield cpy.Block(tree)(Nil, e)
case Inlined(_, Nil, expr) =>
betaReduce(expr)
case _ =>
val argVal = SyntheticValDef(NameKinds.UniqueName.fresh("x".toTermName), arg).withSpan(arg.span)
(argVal :: acc1, ref(argVal.symbol) :: acc2)
}}
val argVals = argVals0.reverse
val argRefs = argRefs0.reverse
val reducedBody = lambdaExtractor(fn, argRefs.map(_.tpe)) match {
case Some(body) => body(argRefs)
case None => fn.select(nme.apply).appliedToArgs(argRefs)
}
seq(argVals, reducedBody).withSpan(fn.span)
}
None

def lambdaExtractor(fn: Term, paramTypes: List[Type])(using Context): Option[List[Term] => Term] = {
def rec(fn: Term, transformBody: Term => Term): Option[List[Term] => Term] = {
Expand Down
31 changes: 22 additions & 9 deletions compiler/src/dotty/tools/dotc/transform/BetaReduce.scala
Expand Up @@ -37,22 +37,26 @@ class BetaReduce extends MiniPhase:

override def transformApply(app: Apply)(using Context): Tree = app.fun match
case Select(fn, nme.apply) if defn.isFunctionType(fn.tpe) =>
val app1 = betaReduce(app, fn, app.args)
val app1 = BetaReduce(app, fn, app.args)
if app1 ne app then report.log(i"beta reduce $app -> $app1")
app1
case _ =>
app

private def betaReduce(tree: Apply, fn: Tree, args: List[Tree])(using Context): Tree =
fn match
case Typed(expr, _) => betaReduce(tree, expr, args)
case Block(Nil, expr) => betaReduce(tree, expr, args)
case Block((anonFun: DefDef) :: Nil, closure: Closure) => BetaReduce(anonFun, args)
case _ => tree

object BetaReduce:
import ast.tpd._

/** Beta-reduces a call to `fn` with arguments `argSyms` or returns `tree` */
def apply(tree: Apply, fn: Tree, args: List[Tree])(using Context): Tree =
fn match
case Typed(expr, _) => BetaReduce(tree, expr, args)
case Block(Nil, expr) => BetaReduce(tree, expr, args)
case Inlined(_, Nil, expr) => BetaReduce(tree, expr, args)
case Block((anonFun: DefDef) :: Nil, closure: Closure) => BetaReduce(anonFun, args)
case _ => tree
end apply

/** Beta-reduces a call to `ddef` with arguments `argSyms` */
def apply(ddef: DefDef, args: List[Tree])(using Context) =
val bindings = List.newBuilder[ValDef]
Expand All @@ -65,7 +69,8 @@ object BetaReduce:
ref.symbol
case _ =>
val flags = Synthetic | (param.symbol.flags & Erased)
val binding = ValDef(newSymbol(ctx.owner, param.name, flags, arg.tpe.widen, coord = arg.span), arg)
val tpe = if arg.tpe.dealias.isInstanceOf[ConstantType] then arg.tpe.dealias else arg.tpe.widen
val binding = ValDef(newSymbol(ctx.owner, param.name, flags, tpe, coord = arg.span), arg).withSpan(arg.span)
bindings += binding
binding.symbol

Expand All @@ -76,5 +81,13 @@ object BetaReduce:
substTo = argSyms
).transform(ddef.rhs)

seq(bindings.result(), expansion)
val expansion1 = new TreeMap {
override def transform(tree: Tree)(using Context) = tree.tpe.widenTermRefExpr match
case ConstantType(const) if isPureExpr(tree) => cpy.Literal(tree)(const)
case _ => super.transform(tree)
}.transform(expansion)
val bindings1 =
bindings.result().filterNot(vdef => vdef.tpt.tpe.isInstanceOf[ConstantType] && isPureExpr(vdef.rhs))

seq(bindings1, expansion1)
end apply
41 changes: 30 additions & 11 deletions compiler/test/dotty/tools/backend/jvm/InlineBytecodeTests.scala
Expand Up @@ -399,22 +399,15 @@ class InlineBytecodeTests extends DottyBytecodeTest {
val instructions = instructionsFromMethod(fun)
val expected =
List(
// Head tested separatly
VarOp(ALOAD, 0),
Invoke(INVOKEVIRTUAL, "Test", "given_Int", "()I", false),
Invoke(INVOKESTATIC, "scala/runtime/BoxesRunTime", "boxToInteger", "(I)Ljava/lang/Integer;", false),
Invoke(INVOKEINTERFACE, "dotty/runtime/function/JFunction1$mcZI$sp", "apply", "(Ljava/lang/Object;)Ljava/lang/Object;", true),
Invoke(INVOKESTATIC, "scala/runtime/BoxesRunTime", "unboxToBoolean", "(Ljava/lang/Object;)Z", false),
VarOp(ISTORE, 1),
Op(ICONST_1),
Op(IRETURN)
)

instructions.head match {
case InvokeDynamic(INVOKEDYNAMIC, "apply$mcZI$sp", "()Ldotty/runtime/function/JFunction1$mcZI$sp;", _, _) =>
case _ => assert(false, "`g` was not properly inlined in `test`\n")
}

assert(instructions.tail == expected,
"`fg was not properly inlined in `test`\n" + diffInstructions(instructions.tail, expected))
assert(instructions == expected,
"`fg was not properly inlined in `test`\n" + diffInstructions(instructions, expected))

}
}
Expand Down Expand Up @@ -505,4 +498,30 @@ class InlineBytecodeTests extends DottyBytecodeTest {
}
}


@Test def i9466 = {
val source = """class Test:
| inline def i(inline f: Int => Boolean): String =
| if f(34) then "a"
| else "b"
| def test = i(f = _ == 34)
""".stripMargin

checkBCode(source) { dir =>
val clsIn = dir.lookupName("Test.class", directory = false).input
val clsNode = loadClassNode(clsIn)

val fun = getMethod(clsNode, "test")
val instructions = instructionsFromMethod(fun)
val expected =
List(
Ldc(LDC, "a"),
Op(ARETURN)
)

assert(instructions == expected,
"`i was not properly inlined in `test`\n" + diffInstructions(instructions, expected))

}
}
}
29 changes: 6 additions & 23 deletions library/src-bootstrapped/scala/quoted/Expr.scala
Expand Up @@ -61,30 +61,13 @@ abstract class Expr[+T] private[scala] {

object Expr {

/** Converts a tuple `(T1, ..., Tn)` to `(Expr[T1], ..., Expr[Tn])` */
type TupleOfExpr[Tup <: Tuple] = Tuple.Map[Tup, [X] =>> QuoteContext ?=> Expr[X]]

/** `Expr.betaReduce(f)(x1, ..., xn)` is functionally the same as `'{($f)($x1, ..., $xn)}`, however it optimizes this call
* by returning the result of beta-reducing `f(x1, ..., xn)` if `f` is a known lambda expression.
*
* `Expr.betaReduce` distributes applications of `Expr` over function arrows
* ```scala
* Expr.betaReduce(_): Expr[(T1, ..., Tn) => R] => ((Expr[T1], ..., Expr[Tn]) => Expr[R])
* ```
*/
def betaReduce[F, Args <: Tuple, R, G](f: Expr[F])(using tf: TupledFunction[F, Args => R], tg: TupledFunction[G, TupleOfExpr[Args] => Expr[R]], qctx: QuoteContext): G =
tg.untupled(args => qctx.tasty.internal.betaReduce(f.unseal, args.toArray.toList.map(_.asInstanceOf[QuoteContext => Expr[Any]](qctx).unseal)).seal.asInstanceOf[Expr[R]])

/** `Expr.betaReduceGiven(f)(x1, ..., xn)` is functionally the same as `'{($f)(using $x1, ..., $xn)}`, however it optimizes this call
* by returning the result of beta-reducing `f(using x1, ..., xn)` if `f` is a known lambda expression.
*
* `Expr.betaReduceGiven` distributes applications of `Expr` over function arrows
* ```scala
* Expr.betaReduceGiven(_): Expr[(T1, ..., Tn) ?=> R] => ((Expr[T1], ..., Expr[Tn]) => Expr[R])
* ```
/** `e.betaReduce` returns a option with a expression that is functionally equivalent to `e`,
* however if `e` is of the form `((y1, ..., yn) => e2)(x1, ..., xn)`
* then it optimizes this the top most call by returning `Some` with the result of beta-reducing the application.
* Otherwise returns None.
*/
def betaReduceGiven[F, Args <: Tuple, R, G](f: Expr[F])(using tf: TupledFunction[F, Args ?=> R], tg: TupledFunction[G, TupleOfExpr[Args] => Expr[R]], qctx: QuoteContext): G =
tg.untupled(args => qctx.tasty.internal.betaReduce(f.unseal, args.toArray.toList.map(_.asInstanceOf[QuoteContext => Expr[Any]](qctx).unseal)).seal.asInstanceOf[Expr[R]])
def betaReduce[T](expr: Expr[T])(using qctx: QuoteContext): Option[Expr[T]] =
qctx.tasty.internal.betaReduce(expr.unseal).map(_.seal.asInstanceOf[Expr[T]])

/** Returns a null expresssion equivalent to `'{null}` */
def nullExpr: QuoteContext ?=> Expr[Null] = qctx ?=> {
Expand Down
Expand Up @@ -22,7 +22,7 @@ object UnsafeExpr {
def underlyingArgument[T](expr: Expr[T])(using qctx: QuoteContext): Expr[T] =
expr.unseal.underlyingArgument.seal.asInstanceOf[Expr[T]]

// TODO generalize for any function arity (see Expr.betaReduce)
// TODO generalize for any function arity
/** Allows inspection or transformation of the body of the expression of function.
* This body may have references to the arguments of the function which should be closed
* over if the expression will be spliced.
Expand Down
6 changes: 2 additions & 4 deletions library/src/scala/tasty/reflect/CompilerInterface.scala
Expand Up @@ -1600,10 +1600,8 @@ trait CompilerInterface {
*/
def searchImplicit(tpe: Type)(using ctx: Context): ImplicitSearchResult

/** Inline fn if it is an explicit closure possibly nested inside the expression of a block.
* Otherwise apply the arguments to the closure.
*/
def betaReduce(f: Term, args: List[Term])(using ctx: Context): Term
/** Returns Some with a beta-reduced application or None */
def betaReduce(tree: Term)(using Context): Option[Term]

def lambdaExtractor(term: Term, paramTypes: List[Type])(using ctx: Context): Option[List[Term] => Term]

Expand Down
4 changes: 2 additions & 2 deletions tests/neg-macros/beta-reduce-inline-result/Macro_1.scala
Expand Up @@ -4,7 +4,7 @@ object Macros {
inline def betaReduce[Arg,Result](inline fn: Arg=>Result)(inline arg: Arg): Result =
${ betaReduceImpl('{ fn })('{ arg }) }

def betaReduceImpl[Arg,Result](fn: Expr[Arg=>Result])(arg: Expr[Arg])(using qctx: QuoteContext): Expr[Result] =
Expr.betaReduce(fn)(arg)
def betaReduceImpl[Arg: Type, Result: Type](fn: Expr[Arg=>Result])(arg: Expr[Arg])(using qctx: QuoteContext): Expr[Result] =
Expr.betaReduce('{$fn($arg)}).get
}

2 changes: 1 addition & 1 deletion tests/pos-macros/i6783.scala
@@ -1,6 +1,6 @@
import scala.quoted._

def testImpl(f: Expr[(Int, Int) => Int])(using QuoteContext): Expr[Int] = Expr.betaReduce(f)('{1}, '{2})
def testImpl(f: Expr[(Int, Int) => Int])(using QuoteContext): Expr[Int] = Expr.betaReduce('{$f(1, 2)}).get

inline def test(f: (Int, Int) => Int) = ${
testImpl('f)
Expand Down
4 changes: 2 additions & 2 deletions tests/run-macros/beta-reduce-inline-result.check
@@ -1,6 +1,6 @@
compile-time: ((3.+(1): scala.Int): scala.Int)
compile-time: (4: scala.Int)
run-time: 4
compile-time: ((1: 1): scala.Int)
compile-time: (1: scala.Int)
run-time: 1
run-time: 5
run-time: 7
Expand Down
10 changes: 6 additions & 4 deletions tests/run-macros/beta-reduce-inline-result/Macro_1.scala
Expand Up @@ -4,13 +4,15 @@ object Macros {
inline def betaReduce[Arg,Result](inline fn : Arg=>Result)(inline arg: Arg): Result =
${ betaReduceImpl('{ fn })('{ arg }) }

def betaReduceImpl[Arg,Result](fn: Expr[Arg=>Result])(arg: Expr[Arg])(using qctx : QuoteContext): Expr[Result] =
Expr.betaReduce(fn)(arg)
def betaReduceImpl[Arg: Type, Result: Type](fn: Expr[Arg=>Result])(arg: Expr[Arg])(using qctx : QuoteContext): Expr[Result] =
val app = '{$fn($arg)}
Expr.betaReduce(app).getOrElse(app)

inline def betaReduceAdd1[Arg](inline fn: Arg=>Int)(inline arg: Arg): Int =
${ betaReduceAdd1Impl('{ fn })('{ arg }) }

def betaReduceAdd1Impl[Arg](fn: Expr[Arg=>Int])(arg: Expr[Arg])(using qctx: QuoteContext): Expr[Int] =
'{ ${ Expr.betaReduce(fn)(arg) } + 1 }
def betaReduceAdd1Impl[Arg: Type](fn: Expr[Arg=>Int])(arg: Expr[Arg])(using qctx: QuoteContext): Expr[Int] =
val app = '{$fn.asInstanceOf[Arg=>Int]($arg)} // FIXME: remove asInstanceOf (workaround for #8612)
'{ ${ Expr.betaReduce(app).getOrElse(app) } + 1 }
}

1 change: 0 additions & 1 deletion tests/run-macros/beta-reduce-inline-result/Test_2.scala
Expand Up @@ -74,4 +74,3 @@ object Test {
println(s"run-time: ${Macros.betaReduce(dummy7)(8)}")
}
}

4 changes: 2 additions & 2 deletions tests/run-macros/gestalt-optional-staging/Macro_1.scala
Expand Up @@ -9,7 +9,7 @@ final class Optional[+A >: Null](val value: A) extends AnyVal {

inline def getOrElse[B >: A](alt: => B): B = ${ Optional.getOrElseImpl('this, 'alt) }

inline def map[B >: Null](f: A => B): Optional[B] = ${ Optional.mapImpl('this, 'f) }
inline def map[B >: Null](inline f: A => B): Optional[B] = ${ Optional.mapImpl('this, 'f) }

override def toString = if (isEmpty) "<empty>" else s"$value"
}
Expand All @@ -24,7 +24,7 @@ object Optional {
// FIXME fix issue #5097 and enable private
/*private*/ def mapImpl[A >: Null : Type, B >: Null : Type](opt: Expr[Optional[A]], f: Expr[A => B])(using QuoteContext): Expr[Optional[B]] = '{
if ($opt.isEmpty) new Optional(null)
else new Optional(${Expr.betaReduce(f)('{$opt.value})})
else new Optional(${Expr.betaReduce('{$f($opt.value)}).get})
}

}
4 changes: 2 additions & 2 deletions tests/run-macros/i4734/Macro_1.scala
Expand Up @@ -2,7 +2,7 @@ import scala.annotation.tailrec
import scala.quoted._

object Macros {
inline def unrolledForeach(seq: IndexedSeq[Int], f: => Int => Unit, inline unrollSize: Int): Unit = // or f: Int => Unit
inline def unrolledForeach(seq: IndexedSeq[Int], inline f: Int => Unit, inline unrollSize: Int): Unit = // or f: Int => Unit
${ unrolledForeachImpl('seq, 'f, 'unrollSize) }

def unrolledForeachImpl(seq: Expr[IndexedSeq[Int]], f: Expr[Int => Unit], unrollSizeExpr: Expr[Int]) (using QuoteContext): Expr[Unit] =
Expand All @@ -17,7 +17,7 @@ object Macros {
for (j <- new UnrolledRange(0, unrollSize)) '{
val index = i + ${Expr(j)}
val element = ($seq)(index)
${ Expr.betaReduce(f)('element) } // or `($f)(element)` if `f` should not be inlined
${ Expr.betaReduce('{$f(element)}).get } // or `($f)(element)` if `f` should not be inlined
}
}
i += ${Expr(unrollSize)}
Expand Down
2 changes: 1 addition & 1 deletion tests/run-macros/i4735/App_2.scala
Expand Up @@ -10,5 +10,5 @@ object Test {
}

class Unrolled(arr: Array[Int]) extends AnyVal {
inline def foreach(f: => Int => Unit): Unit = Macro.unrolledForeach(3, arr, f)
inline def foreach(inline f: Int => Unit): Unit = Macro.unrolledForeach(3, arr, f)
}
4 changes: 2 additions & 2 deletions tests/run-macros/i4735/Macro_1.scala
Expand Up @@ -4,7 +4,7 @@ import scala.quoted._

object Macro {

inline def unrolledForeach(inline unrollSize: Int, seq: Array[Int], f: => Int => Unit): Unit = // or f: Int => Unit
inline def unrolledForeach(inline unrollSize: Int, seq: Array[Int], inline f: Int => Unit): Unit = // or f: Int => Unit
${ unrolledForeachImpl('unrollSize, 'seq, 'f) }

private def unrolledForeachImpl(unrollSize: Expr[Int], seq: Expr[Array[Int]], f: Expr[Int => Unit]) (using QuoteContext): Expr[Unit] = '{
Expand All @@ -16,7 +16,7 @@ object Macro {
${
for (j <- new UnrolledRange(0, unrollSize.unliftOrError)) '{
val element = ($seq)(i + ${Expr(j)})
${Expr.betaReduce(f)('element)} // or `($f)(element)` if `f` should not be inlined
${Expr.betaReduce('{$f(element)}).get} // or `($f)(element)` if `f` should not be inlined
}
}
i += ${unrollSize}
Expand Down
2 changes: 1 addition & 1 deletion tests/run-macros/i7008/macro_1.scala
Expand Up @@ -13,5 +13,5 @@ def mcrProxy(expr: Expr[Boolean])(using QuoteContext): Expr[Unit] = {
def mcrImpl[T](func: Expr[Seq[Box[T]] => Unit], expr: Expr[T])(using ctx: QuoteContext, tt: Type[T]): Expr[Unit] = {
import ctx.tasty._
val arg = Varargs(Seq('{(Box($expr))}))
Expr.betaReduce(func)(arg)
Expr.betaReduce('{$func($arg)}).get
}
20 changes: 8 additions & 12 deletions tests/run-macros/quote-inline-function.check
Expand Up @@ -3,13 +3,11 @@ Normal function
var i: scala.Int = 0
val j: scala.Int = 5
while (i.<(j)) {
val x$1: scala.Int = i
f.apply(x$1)
f.apply(i)
i = i.+(1)
}
while ({
val x$2: scala.Int = i
f.apply(x$2)
f.apply(i)
i = i.+(1)
i.<(j)
}) ()
Expand All @@ -20,13 +18,11 @@ By name function
var i: scala.Int = 0
val j: scala.Int = 5
while (i.<(j)) {
val x$3: scala.Int = i
f.apply(x$3)
f.apply(i)
i = i.+(1)
}
while ({
val x$4: scala.Int = i
f.apply(x$4)
f.apply(i)
i = i.+(1)
i.<(j)
}) ()
Expand All @@ -37,13 +33,13 @@ Inline function
var i: scala.Int = 0
val j: scala.Int = 5
while (i.<(j)) {
val x$5: scala.Int = i
scala.Predef.println(x$5)
val x: scala.Int = i
scala.Predef.println(x)
i = i.+(1)
}
while ({
val x$6: scala.Int = i
scala.Predef.println(x$6)
val `x₂`: scala.Int = i
scala.Predef.println(`x₂`)
i = i.+(1)
i.<(j)
}) ()
Expand Down
4 changes: 2 additions & 2 deletions tests/run-macros/quote-inline-function/quoted_1.scala
Expand Up @@ -12,11 +12,11 @@ object Macros {
var i = $start
val j = $end
while (i < j) {
${Expr.betaReduce(f)('i)}
${Expr.betaReduce('{$f(i)}).getOrElse('{$f(i)})}
i += 1
}
while {
${Expr.betaReduce(f)('i)}
${Expr.betaReduce('{$f(i)}).getOrElse('{$f(i)})}
i += 1
i < j
} do ()
Expand Down

0 comments on commit 34d2587

Please sign in to comment.