Skip to content

Commit

Permalink
Merge pull request #8665 from dotty-staging/fix-#8585-2
Browse files Browse the repository at this point in the history
Fix #8585: Refresh names to avoid name clashes
  • Loading branch information
nicolasstucki committed Apr 5, 2020
2 parents 3d26c53 + 9b5573a commit 895e361
Show file tree
Hide file tree
Showing 12 changed files with 201 additions and 137 deletions.
19 changes: 17 additions & 2 deletions library/src/scala/tasty/reflect/SourceCodePrinter.scala
Expand Up @@ -324,7 +324,7 @@ class SourceCodePrinter[R <: Reflection & Singleton](val tasty: R)(syntaxHighlig

case tree: Ident =>
splicedName(tree.symbol) match {
case Some(name) => this += name
case Some(name) => this += highlightTypeDef(name)
case _ => printType(tree.tpe)
}

Expand Down Expand Up @@ -834,7 +834,7 @@ class SourceCodePrinter[R <: Reflection & Singleton](val tasty: R)(syntaxHighlig
}

def printParamDef(arg: ValDef)(using elideThis: Option[Symbol]): Unit = {
val name = arg.name
val name = splicedName(arg.symbol).getOrElse(arg.symbol.name)
val sym = arg.symbol.owner
if sym.isDefDef && sym.name == "<init>" then
val ClassDef(_, _, _, _, _, body) = sym.owner.tree
Expand Down Expand Up @@ -1409,11 +1409,26 @@ class SourceCodePrinter[R <: Reflection & Singleton](val tasty: R)(syntaxHighlig
private def escapedString(str: String): String = str flatMap escapedChar
}

private[this] val names = collection.mutable.Map.empty[Symbol, String]
private[this] val namesIndex = collection.mutable.Map.empty[String, Int]

private def splicedName(sym: Symbol)(using ctx: Context): Option[String] = {
sym.annots.find(_.symbol.owner == ctx.requiredClass("scala.internal.quoted.showName")).flatMap {
case Apply(_, Literal(Constant(c: String)) :: Nil) => Some(c)
case Apply(_, Inlined(_, _, Literal(Constant(c: String))) :: Nil) => Some(c)
case annot => None
}.orElse {
if sym.owner.isClassDef then None
else names.get(sym).orElse {
val name0 = sym.name
val index = namesIndex.getOrElse(name0, 1)
namesIndex(name0) = index + 1
val name =
if index == 1 then name0
else s"`$name0${index.toString.toCharArray.map {x => (x - '0' + '₀').toChar}.mkString}`"
names(sym) = name
Some(name)
}
}
}

Expand Down
2 changes: 1 addition & 1 deletion tests/run-macros/flops-rewrite.check
@@ -1,7 +1,7 @@
scala.Nil.map[scala.Nothing](((x: scala.Nothing) => x))
scala.Nil

scala.Nil.map[scala.Nothing](((x: scala.Nothing) => x)).++[scala.Nothing](scala.Nil.map[scala.Nothing](((x: scala.Nothing) => x)))
scala.Nil.map[scala.Nothing](((x: scala.Nothing) => x)).++[scala.Nothing](scala.Nil.map[scala.Nothing](((`x₂`: scala.Nothing) => `x₂`)))
scala.Nil

scala.Nil.map[scala.Nothing](((x: scala.Nothing) => x)).++[scala.Int](scala.List.apply[scala.Int](3)).++[scala.Int](scala.Nil)
Expand Down
8 changes: 4 additions & 4 deletions tests/run-macros/quote-matching-optimize-1.check
@@ -1,18 +1,18 @@
Original: ls.filter(((x: scala.Int) => x.<(3))).filter(((x: scala.Int) => x.>(1)))
Original: ls.filter(((x: scala.Int) => x.<(3))).filter(((`x₂`: scala.Int) => `x₂`.>(1)))
Optimized: ls.filter(((x: scala.Int) => x.<(3).&&(x.>(1))))
Result: List(2)

Original: ls2.filter(((x: scala.Char) => x.<('c'))).filter(((x: scala.Char) => x.>('a')))
Original: ls2.filter(((x: scala.Char) => x.<('c'))).filter(((`x₂`: scala.Char) => `x₂`.>('a')))
Optimized: ls2.filter(((x: scala.Char) => x.<('c').&&(x.>('a'))))
Result: List(b)

Original: ls.filter(((x: scala.Int) => x.<(3))).filter(((x: scala.Int) => x.>(1))).filter(((x: scala.Int) => x.==(2)))
Original: ls.filter(((x: scala.Int) => x.<(3))).filter(((`x₂`: scala.Int) => `x₂`.>(1))).filter(((`x₃`: scala.Int) => `x₃`.==(2)))
Optimized: ls.filter(((x: scala.Int) => x.<(3).&&(x.>(1).&&(x.==(2)))))
Result: List(2)

1
2
Original: ls.filter(((x: scala.Int) => x.<(3))).foreach[scala.Unit](((x: scala.Int) => scala.Predef.println(x)))
Original: ls.filter(((x: scala.Int) => x.<(3))).foreach[scala.Unit](((`x₂`: scala.Int) => scala.Predef.println(`x₂`)))
Optimized: ls.foreach[scala.Unit](((x: scala.Int) => if (x.<(3)) scala.Predef.println(x) else ()))
Result: ()

Expand Down
8 changes: 4 additions & 4 deletions tests/run-macros/quote-matching-optimize-2.check
@@ -1,18 +1,18 @@
Original: ls.filter(((x: scala.Int) => x.<(3))).filter(((x: scala.Int) => x.>(1)))
Original: ls.filter(((x: scala.Int) => x.<(3))).filter(((`x₂`: scala.Int) => `x₂`.>(1)))
Optimized: ls.filter(((x: scala.Int) => x.<(3).&&(x.>(1))))
Result: List(2)

Original: ls2.filter(((x: scala.Char) => x.<('c'))).filter(((x: scala.Char) => x.>('a')))
Original: ls2.filter(((x: scala.Char) => x.<('c'))).filter(((`x₂`: scala.Char) => `x₂`.>('a')))
Optimized: ls2.filter(((x: scala.Char) => x.<('c').&&(x.>('a'))))
Result: List(b)

Original: ls.filter(((x: scala.Int) => x.<(3))).filter(((x: scala.Int) => x.>(1))).filter(((x: scala.Int) => x.==(2)))
Original: ls.filter(((x: scala.Int) => x.<(3))).filter(((`x₂`: scala.Int) => `x₂`.>(1))).filter(((`x₃`: scala.Int) => `x₃`.==(2)))
Optimized: ls.filter(((x: scala.Int) => x.<(3).&&(x.>(1).&&(x.==(2)))))
Result: List(2)

1
2
Original: ls.filter(((x: scala.Int) => x.<(3))).foreach[scala.Unit](((x: scala.Int) => scala.Predef.println(x)))
Original: ls.filter(((x: scala.Int) => x.<(3))).foreach[scala.Unit](((`x₂`: scala.Int) => scala.Predef.println(`x₂`)))
Optimized: ls.foreach[scala.Any](((x: scala.Int) => if (x.<(3)) scala.Predef.println(x) else ()))
Result: ()

Expand Down
16 changes: 8 additions & 8 deletions tests/run-macros/quote-matching-optimize-3.check
@@ -1,19 +1,19 @@
Original: ls.filter(((x: scala.Int) => x.<(3))).filter(((x: scala.Int) => x.>(1)))
Optimized: ls.filter(((x: scala.Int) => ((x: scala.Int) => x.<(3)).apply(x).&&(((x: scala.Int) => x.>(1)).apply(x))))
Original: ls.filter(((x: scala.Int) => x.<(3))).filter(((`x₂`: scala.Int) => `x₂`.>(1)))
Optimized: ls.filter(((x: scala.Int) => ((`x₂`: scala.Int) => `x₂`.<(3)).apply(x).&&(((`x₃`: scala.Int) => `x₃`.>(1)).apply(x))))
Result: List(2)

Original: ls2.filter(((x: scala.Char) => x.<('c'))).filter(((x: scala.Char) => x.>('a')))
Optimized: ls2.filter(((x: scala.Char) => ((x: scala.Char) => x.<('c')).apply(x).&&(((x: scala.Char) => x.>('a')).apply(x))))
Original: ls2.filter(((x: scala.Char) => x.<('c'))).filter(((`x₂`: scala.Char) => `x₂`.>('a')))
Optimized: ls2.filter(((x: scala.Char) => ((`x₂`: scala.Char) => `x₂`.<('c')).apply(x).&&(((`x₃`: scala.Char) => `x₃`.>('a')).apply(x))))
Result: List(b)

Original: ls.filter(((x: scala.Int) => x.<(3))).filter(((x: scala.Int) => x.>(1))).filter(((x: scala.Int) => x.==(2)))
Optimized: ls.filter(((x: scala.Int) => ((x: scala.Int) => x.<(3)).apply(x).&&(((x: scala.Int) => ((x: scala.Int) => x.>(1)).apply(x).&&(((x: scala.Int) => x.==(2)).apply(x))).apply(x))))
Original: ls.filter(((x: scala.Int) => x.<(3))).filter(((`x₂`: scala.Int) => `x₂`.>(1))).filter(((`x₃`: scala.Int) => `x₃`.==(2)))
Optimized: ls.filter(((x: scala.Int) => ((`x₂`: scala.Int) => `x₂`.<(3)).apply(x).&&(((`x₃`: scala.Int) => ((`x₄`: scala.Int) => `x₄`.>(1)).apply(`x₃`).&&(((`x₅`: scala.Int) => `x₅`.==(2)).apply(`x₃`))).apply(x))))
Result: List(2)

1
2
Original: ls.filter(((x: scala.Int) => x.<(3))).foreach[scala.Unit](((x: scala.Int) => scala.Predef.println(x)))
Optimized: ls.foreach[scala.Any](((x: scala.Int) => if (((x: scala.Int) => x.<(3)).apply(x)) ((x: scala.Int) => scala.Predef.println(x)).apply(x) else ()))
Original: ls.filter(((x: scala.Int) => x.<(3))).foreach[scala.Unit](((`x₂`: scala.Int) => scala.Predef.println(`x₂`)))
Optimized: ls.foreach[scala.Any](((x: scala.Int) => if (((`x₂`: scala.Int) => `x₂`.<(3)).apply(x)) ((`x₃`: scala.Int) => scala.Predef.println(`x₃`)).apply(x) else ()))
Result: ()

Original: ls.map[scala.Long](((a: scala.Int) => a.toLong)).map[java.lang.String](((b: scala.Long) => b.toString()))
Expand Down
2 changes: 1 addition & 1 deletion tests/run-staging/i3876-c.check
Expand Up @@ -2,7 +2,7 @@
{
val f: scala.Function1[scala.Int, scala.Int] {
def apply(x: scala.Int): scala.Int
} = ((x: scala.Int) => x.+(x))
} = ((`x₂`: scala.Int) => `x₂`.+(`x₂`))

(f: scala.Function1[scala.Int, scala.Int] {
def apply(x: scala.Int): scala.Int
Expand Down
4 changes: 2 additions & 2 deletions tests/run-staging/i5144.check
@@ -1,4 +1,4 @@
{
def f(x: scala.Int): scala.Int = ((x: scala.Int) => f(x)).apply(42)
def f(x: scala.Int): scala.Int = ((`x₂`: scala.Int) => f(`x₂`)).apply(42)
()
}
}
22 changes: 22 additions & 0 deletions tests/run-staging/i8585.check
@@ -0,0 +1,22 @@
The following would not compile:
((x: scala.Double) => {
val y: scala.Double = x.*(x)
val `y₂`: scala.Double = y.*(y)
`y₂`.*(`y₂`)
})
3^8 = 6561.0
The following would not compile:
((x: scala.Double) => x.*({
val y: scala.Double = x.*(x)
y.*({
val `y₂`: scala.Double = y.*(y)
`y₂`.*({
val `y₃`: scala.Double = `y₂`.*(`y₂`)
`y₃`.*({
val `y₄`: scala.Double = `y₃`.*(`y₃`)
`y₄`.*(`y₄`)
})
})
})
}))
2^47 = 1.40737488355328E14
27 changes: 27 additions & 0 deletions tests/run-staging/i8585.scala
@@ -0,0 +1,27 @@
import scala.quoted._
import scala.quoted.staging.{run, withQuoteContext, Toolbox}

object Test {
given Toolbox = Toolbox.make(getClass.getClassLoader)

def main(args: Array[String]): Unit = {
val toTheEighth = stagedPower(8)
println("3^8 = " + toTheEighth(3))

val toThe47 = stagedPower(47)
println("2^47 = " + toThe47(2))
}

def stagedPower(n: Int): Double => Double = {
def code(using QuoteContext) = '{ (x: Double) => ${ powerCode(n, 'x) } }
println("The following would not compile:")
println(withQuoteContext(code.show))
run(code)
}

def powerCode(n: Int, x: Expr[Double])(using ctx: QuoteContext): Expr[Double] =
if (n == 1) x
else if (n == 2) '{ $x * $x }
else if (n % 2 == 1) '{ $x * ${ powerCode(n - 1, x) } }
else '{ val y = $x * $x; ${ powerCode(n / 2, 'y) } }
}
10 changes: 5 additions & 5 deletions tests/run-staging/quote-run-2.check
Expand Up @@ -11,11 +11,11 @@
{
val y: scala.Double = 5.0.*(5.0)
y.*({
val y: scala.Double = y.*(y)
y.*({
val y: scala.Double = y.*(y)
val y: scala.Double = y.*(y)
y
val `y₂`: scala.Double = y.*(y)
`y₂`.*({
val `y₃`: scala.Double = `y₂`.*(`y₂`)
val `y₄`: scala.Double = `y₃`.*(`y₃`)
`y₄`
})
})
}
2 changes: 1 addition & 1 deletion tests/run-staging/quote-unrolled-foreach.check
Expand Up @@ -33,7 +33,7 @@
var i: scala.Int = 0
while (i.<(size)) {
val element: scala.Int = arr.apply(i)
((i: scala.Int) => java.lang.System.out.println(i)).apply(element)
((`i₂`: scala.Int) => java.lang.System.out.println(`i₂`)).apply(element)
i = i.+(1)
}
})
Expand Down

0 comments on commit 895e361

Please sign in to comment.