Skip to content

Commit

Permalink
Merge pull request #558 from scalacenter/anonfun
Browse files Browse the repository at this point in the history
find anonymous function symbols
  • Loading branch information
aymanelamyaghri committed Aug 7, 2023
2 parents 4f76bcc + b8f2965 commit fa78473
Show file tree
Hide file tree
Showing 5 changed files with 115 additions and 36 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,11 @@ object LazyInit:
val lazyInit = "(.*)\\$lzyINIT\\d+".r
lazyInit.unapplySeq(NameTransformer.decode(method.name)).map(xs => xs(0))

object AnonFun:
def unapply(method: binary.Method): Option[String] =
val anonFun = "(.*)\\$anonfun\\$\\d+".r
anonFun.unapplySeq(NameTransformer.decode(method.name)).map(xs => xs(0).stripSuffix("$"))

object LocalMethod:
def unapply(method: binary.Method): Option[(String, Int)] =
if method.name.contains("$default") || method.name.contains("$proxy") then None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,19 @@ class Scala3Unpickler(
case t: TermSymbol if (t.isLazyVal || t.isModuleVal) && t.matchName(name) => t
}
yield term
case AnonFun(prefix) =>
val symbols =
for
owner <- withCompanionIfExtendsAnyVal(cls)
term <- collectLocalSymbols(owner) {
case t: TermSymbol if t.isAnonFun && matchSignature(method, t) => t
}
yield term
if symbols.size > 1 && prefix.nonEmpty then
val filteredSymbols = symbols.filter(s => matchPrefix(prefix, s.owner))
if filteredSymbols.size == 0 then symbols
else filteredSymbols
else symbols
case LocalMethod(name, _) =>
for
owner <- withCompanionIfExtendsAnyVal(cls)
Expand All @@ -84,6 +97,30 @@ class Scala3Unpickler(
.filter(matchSymbol(method, _))
candidates.singleOptOrThrow(method.name)

def matchPrefix(prefix: String, owner: Symbol): Boolean =
if prefix.isEmpty then true
else if prefix.endsWith("$_") then
val stripped = prefix.stripSuffix("$$_")
matchPrefix(stripped, owner)
else if prefix.endsWith("$init$") then owner.isTerm && !owner.asTerm.isMethod
else
val regex = owner.name.toString match
case "$anonfun" => "\\$anonfun\\$\\d+$"
case name =>
Regex.quote(name)
+ (if owner.isLocal then "\\$\\d+" else "")
+ (if owner.isModuleClass then "\\$" else "")
+ "$"
regex.r.findFirstIn(prefix) match
case Some(suffix) =>
def enclosingDecl(owner: Symbol): DeclaringSymbol =
if owner.isInstanceOf[DeclaringSymbol] then owner.asInstanceOf[DeclaringSymbol]
else enclosingDecl(owner.owner)
val superOwner =
if owner.isLocal && !owner.isAnonFun then enclosingDecl(owner) else owner.owner
matchPrefix(prefix.stripSuffix(suffix).stripSuffix("$"), superOwner)
case None => false

def withCompanionIfExtendsAnyVal(cls: ClassSymbol): Seq[ClassSymbol] =
cls.companionClass match
case Some(companionClass) if companionClass.isSubclass(ctx.defn.AnyValClass) =>
Expand Down Expand Up @@ -416,8 +453,11 @@ class Scala3Unpickler(

extension (symbol: Symbol)
private def isTrait = symbol.isClass && symbol.asClass.isTrait
private def matchName(name: String) = symbol.name.toString == name
private def isAnonFun = symbol.name.toString() == "$anonfun"
private def matchName(name: String) =
symbol.name.toString == name
private def isLocal = symbol.owner.isTerm
private def isModuleClass = symbol.isClass && symbol.asClass.isModuleClass

extension [T <: Symbol](symbols: Seq[T])
def singleOrThrow(binaryName: String): T =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,21 @@ import java.nio.file.Path
import scala.collection.mutable
import scala.jdk.CollectionConverters.*
import scala.util.Properties
import scala.concurrent.duration.Duration
import scala.concurrent.duration.DurationInt

object Scala3UnpicklerStats:
class Scala3UnpicklerStats extends munit.FunSuite:
private val javaRuntime = JavaRuntime(Properties.jdkHome).get
private val javaRuntimeJars = javaRuntime match
case Java8(_, classJars, _) => classJars
case java9OrAbove: Java9OrAbove =>
java9OrAbove.classSystems.map(_.fileSystem.getPath("/modules", "java.base"))

def main(args: Array[String]): Unit =

val topLevelAndInnerClassCounter = new Counter()
val localClassCounter = new Counter()
val localMethodCounter = new Counter()
test("dotty stats"):
val localClassCounter = new Counter[ClassType]()
val topLevelOrInnerclassCounter = new Counter[ClassType]()
val localMethodCounter = new Counter[Method]()
val anonFunCounter = new Counter[Method]()

val jars = TestingResolver.fetch("org.scala-lang", "scala3-compiler_3", "3.3.0")
val unpickler = new Scala3Unpickler(jars.map(_.absolutePath).toArray ++ javaRuntimeJars, println, testMode = true)
Expand All @@ -36,19 +38,22 @@ object Scala3UnpicklerStats:
cls <- loadClasses(jars, "scala3-compiler_3-3.3.0")
clsSym <- cls match
case LocalClass(_, _, _) => processClass(unpickler, cls, localClassCounter)
case _ => processClass(unpickler, cls, topLevelAndInnerClassCounter)
case _ => processClass(unpickler, cls, topLevelOrInnerclassCounter)
// case AnonClass(_, _, _) => process(cls, anonClassCounter)
// case InnerClass(_, _) => process(cls, innerClassCounter)
// case _ => process(cls, topLevelClassCounter)
method <- cls.declaredMethods
methSym <- method match
case LocalMethod(_, _) => processMethod(unpickler, method, localMethodCounter)
// case LocalLazyInit(_, _, _) => process(method, localClassCounter)
case AnonFun(_) => processMethod(unpickler, method, anonFunCounter)
case LocalMethod(_) => processMethod(unpickler, method, localMethodCounter)
case _ => None
// case LocalLazyInit(_, _, _) => process(method, localClassCounter)
do ()
localClassCounter.printStatus("Local classes")
localMethodCounter.printStatus("Top level and inner classes")
localMethodCounter.printStatus("Local methods")
anonFunCounter.printStatus("anon fun")
topLevelOrInnerclassCounter.printStatus("topLevelOrInnerClass")

def loadClasses(jars: Seq[Library], jarName: String) =
val jar = jars.find(_.name == jarName).get
Expand All @@ -74,52 +79,58 @@ object Scala3UnpicklerStats:
println(s"classNames: ${classes.size}")
classes

def processClass(unpickler: Scala3Unpickler, cls: ClassType, counter: Counter): Option[ClassSymbol] =
def processClass(unpickler: Scala3Unpickler, cls: ClassType, counter: Counter[ClassType]): Option[ClassSymbol] =
try
val sym = unpickler.findClass(cls)
counter.addSuccess(cls.name)
counter.addSuccess(cls)
Some(sym)
catch
case AmbiguousException(e) =>
counter.addAmbiguous(cls.name)
counter.addAmbiguous(cls)
None
case NotFoundException(e) =>
counter.addNotFound(cls.name)
counter.addNotFound(cls)
None
case _ =>
case e =>
counter.exceptions += e.toString
None

def processMethod(unpickler: Scala3Unpickler, mthd: Method, counter: Counter): Option[TermSymbol] =
def processMethod(unpickler: Scala3Unpickler, mthd: Method, counter: Counter[Method]): Option[TermSymbol] =
try
val sym = unpickler.findSymbol(mthd)
sym match
case Some(t) =>
counter.addSuccess(mthd.name)
counter.addSuccess(mthd)
sym
case None =>
counter.addNotFound(mthd.name)
counter.addNotFound(mthd)
None
catch
case AmbiguousException(e) =>
counter.addAmbiguous(mthd.name)
counter.addAmbiguous(mthd)
None
case _ =>
case e =>
counter.exceptions += e.toString
None

class Counter:
val success: mutable.Set[String] = mutable.Set.empty[String]
var notFound: mutable.Set[String] = mutable.Set.empty[String]
var ambiguous: mutable.Set[String] = mutable.Set.empty[String]
override def munitTimeout: Duration = 2.minutes

class Counter[T]:
val success: mutable.Buffer[T] = mutable.Buffer.empty[T]
var notFound: mutable.Buffer[T] = mutable.Buffer.empty[T]
var ambiguous: mutable.Buffer[T] = mutable.Buffer.empty[T]
var exceptions: mutable.Buffer[String] = mutable.Buffer.empty[String]

def addSuccess(cls: String) = success.add(cls)
def addSuccess(cls: T) = success += cls

def addNotFound(cls: String) = notFound.add(cls)
def addNotFound(cls: T) = notFound += cls

def addAmbiguous(cls: String) = ambiguous.add(cls)
def addAmbiguous(cls: T) = ambiguous += cls

def printStatus(m: String) =
println(s"Status $m:")
println(s" - total is ${ambiguous.size + notFound.size + success.size}")
println(s" - success is ${success.size}")
println(s" - ambiguous is ${ambiguous.size}")
println(s" - notFound is ${notFound.size}")
println(s" - exceptions is ${exceptions.size}")
Original file line number Diff line number Diff line change
Expand Up @@ -415,20 +415,43 @@ abstract class Scala3UnpicklerTests(val scalaVersion: ScalaVersion) extends FunS
debuggee.assertFormat("example.A$", "example.A unapply(example.A x$1)", "A.unapply(A): A")
}

test("anonymous function") {
test("anonymous functions") {
val source =
"""|package example
|
|object Main {
| def main(args: Array[String]): Unit = {
| val f = (x: Int) => x + 3
| f(3)
| }
|}
|class A :
| class B :
| def m =
| List(true).map(x => x.toString + 1)
| val f: Int => String = x => ""
| def m =
| List("").map(x => x + 1)
|""".stripMargin
val debuggee = TestingDebuggee.mainClass(source, "example.Main", scalaVersion)
// TODO fix: it should find the symbol f by traversing the tree of object Main
debuggee.assertFind("example.Main$", "int $anonfun$1(int x)")
if isScala30 then
debuggee.assertFormat(
"example.A",
"java.lang.String m$$anonfun$2(boolean x)",
"A.B.m.$anonfun(x: Boolean): String"
)
debuggee.assertFormat("example.A", "java.lang.String $anonfun$1(int x)", "A.B.m.f.$anonfun(x: Int): String")
debuggee.assertFormat(
"example.A",
"java.lang.String m$$anonfun$1(java.lang.String x)",
"A.m.$anonfun(x: String): String"
)
else
debuggee.assertFormat(
"example.A",
"java.lang.String m$$anonfun$1(boolean x)",
"A.B.m.$anonfun(x: Boolean): String"
)
debuggee.assertFormat("example.A", "java.lang.String $anonfun$1(int x)", "A.B.m.f.$anonfun(x: Int): String")
debuggee.assertFormat(
"example.A",
"java.lang.String m$$anonfun$2(java.lang.String x)",
"A.m.$anonfun(x: String): String"
)
}

test("this.type") {
Expand Down
Binary file added sbt-launch.jar
Binary file not shown.

0 comments on commit fa78473

Please sign in to comment.