Skip to content

Commit

Permalink
Elide unit binding when beta-reducing
Browse files Browse the repository at this point in the history
  • Loading branch information
nicolasstucki committed Apr 4, 2024
1 parent ece87c3 commit f055cee
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 1 deletion.
4 changes: 3 additions & 1 deletion compiler/src/dotty/tools/dotc/transform/BetaReduce.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import MegaPhase.*
import Symbols.*, Contexts.*, Types.*, Decorators.*
import StdNames.nme
import ast.TreeTypeMap
import Constants.Constant

import scala.collection.mutable.ListBuffer

Expand Down Expand Up @@ -133,7 +134,7 @@ object BetaReduce:
else if arg.tpe.dealias.isInstanceOf[ConstantType] then arg.tpe.dealias
else arg.tpe.widen
val binding = ValDef(newSymbol(ctx.owner, param.name, flags, tpe, coord = arg.span), arg).withSpan(arg.span)
if !(tpe.isInstanceOf[ConstantType] && isPureExpr(arg)) then
if !((tpe.isInstanceOf[ConstantType] || tpe.derivesFrom(defn.UnitClass)) && isPureExpr(arg)) then
bindings += binding
binding.symbol

Expand All @@ -147,6 +148,7 @@ object BetaReduce:
val expansion1 = new TreeMap {
override def transform(tree: Tree)(using Context) = tree.tpe.widenTermRefExpr match
case ConstantType(const) if isPureExpr(tree) => cpy.Literal(tree)(const)
case tpe: TypeRef if tpe.derivesFrom(defn.UnitClass) && isPureExpr(tree) => cpy.Literal(tree)(Constant(()))
case _ => super.transform(tree)
}.transform(expansion)

Expand Down
20 changes: 20 additions & 0 deletions compiler/test/dotty/tools/backend/jvm/InlineBytecodeTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -765,4 +765,24 @@ class InlineBytecodeTests extends DottyBytecodeTest {
diffInstructions(instructions1, instructions2))
}
}

@Test def beta_reduce_elide_unit_binding = {
val source = """class Test:
| def test = ((u: Unit) => u).apply(())
""".stripMargin

checkBCode(source) { dir =>
val clsIn = dir.lookupName("Test.class", directory = false).input
val clsNode = loadClassNode(clsIn)

val fun = getMethod(clsNode, "test")
val instructions = instructionsFromMethod(fun)
val expected = List(Op(RETURN))

assert(instructions == expected,
"`i was not properly beta-reduced in `test`\n" + diffInstructions(instructions, expected))

}
}

}

0 comments on commit f055cee

Please sign in to comment.