Skip to content

Commit

Permalink
Merge pull request #562 from scalacenter/anonClass
Browse files Browse the repository at this point in the history
Anon class
  • Loading branch information
adpi2 committed Aug 9, 2023
2 parents 5e8b0e6 + ec9e4cc commit 28cc8ca
Show file tree
Hide file tree
Showing 10 changed files with 288 additions and 129 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ trait Method:
def returnType: Option[Type]
def returnTypeName: String
def sourceLines: Seq[Int]
def isBridge: Boolean

def isExtensionMethod: Boolean = name.endsWith("$extension")
def isTraitInitializer: Boolean = name == "$init$"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,6 @@ class JavaReflectConstructor(constructor: Constructor[?], val sourceLines: Seq[I

override def name: String = "<init>"

override def isBridge: Boolean = false

override def toString: String = constructor.toString
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,5 @@ class JavaReflectMethod(method: Method, val sourceLines: Seq[Int]) extends binar
override def name: String = method.getName

override def toString: String = method.toString

override def isBridge: Boolean = method.isBridge
Original file line number Diff line number Diff line change
Expand Up @@ -28,5 +28,7 @@ class JdiMethod(val obj: Any) extends JavaReflection(obj, "com.sun.jdi.Method")
if allDistinctLines.size > 1 then Seq(allDistinctLines.min, allDistinctLines.max)
else allDistinctLines

override def isBridge: Boolean = invokeMethod("isBridge")

private def allLineLocations: Seq[JdiLocation] =
invokeMethod[ju.List[Any]]("allLinesLocations").asScala.map(JdiLocation.apply(_)).toSeq
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package ch.epfl.scala.debugadapter.internal.stacktrace

import tastyquery.Contexts.Context
import tastyquery.Names.*

class Definitions(using ctx: Context):
val scalaPackage = ctx.defn.scalaPackage
val scalaRuntimePackage = scalaPackage.getPackageDecl(SimpleName("runtime")).get
val javaPackage = ctx.defn.RootPackage.getPackageDecl(SimpleName("java")).get
val javaIOPackage = javaPackage.getPackageDecl(SimpleName("io")).get
val partialFunction = scalaPackage.getDecl(typeName("PartialFunction")).get.asClass
val abstractPartialFunction = scalaRuntimePackage.getDecl(typeName("AbstractPartialFunction")).get.asClass
val serializable = javaIOPackage.getDecl(typeName("Serializable")).get.asClass
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,25 @@ object LocalClass:
.filter(xs => xs(1) != "anon")
.map(xs => (xs(0), xs(1), Option(xs(2)).map(_.stripPrefix("$")).filter(_.nonEmpty)))

object AnonClass:
def unapply(cls: binary.ClassType): Option[(String, Option[String])] =
val decodedClassName = NameTransformer.decode(cls.name.split('.').last)
unapply(decodedClassName)
def unapply(decodedClassName: String): Option[(String, Option[String])] =
"(.+)\\$\\$anon\\$\\d+(\\$.*)?".r
.unapplySeq(NameTransformer.decode(decodedClassName))
.map(xs => (xs(0), Option(xs(1)).map(_.stripPrefix("$")).filter(_.nonEmpty)))

object InnerClass:
def unapply(cls: binary.ClassType): Option[String] =
val decodedClassName = NameTransformer.decode(cls.name.split('.').last)
unapply(decodedClassName)

def unapply(decodedClassName: String): Option[String] =
"(.+)\\$(.+)".r
.unapplySeq(NameTransformer.decode(decodedClassName))
.map(_ => decodedClassName)

object LazyInit:
def unapply(method: binary.Method): Option[String] =
val lazyInit = "(.*)\\$lzyINIT\\d+".r
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import ch.epfl.scala.debugadapter.internal.binary
import ch.epfl.scala.debugadapter.internal.jdi.JdiMethod
import tastyquery.Contexts
import tastyquery.Contexts.Context
import tastyquery.Definitions
import tastyquery.Flags
import tastyquery.Names.*
import tastyquery.Signatures.*
Expand All @@ -29,6 +28,7 @@ class Scala3Unpickler(
):
private val classpath = ClasspathLoaders.read(classpaths.toList)
private given ctx: Context = Contexts.init(classpath)
private val defn = new Definitions

private def warn(msg: String): Unit = warnLogger.accept(msg)

Expand Down Expand Up @@ -61,41 +61,48 @@ class Scala3Unpickler(

def findSymbol(method: binary.Method): Option[TermSymbol] =
val cls = findClass(method.declaringClass, method.isExtensionMethod)
val candidates = method match
case LocalLazyInit(name, _) =>
for
owner <- withCompanionIfExtendsAnyVal(cls)
term <- collectLocalSymbols(owner) {
case t: TermSymbol if (t.isLazyVal || t.isModuleVal) && t.matchName(name) => t
}
yield term
case AnonFun(prefix) =>
val symbols =
for
owner <- withCompanionIfExtendsAnyVal(cls)
term <- collectLocalSymbols(owner) {
case t: TermSymbol if t.isAnonFun && matchSignature(method, t) => t
}
yield term
if symbols.size > 1 && prefix.nonEmpty then
val filteredSymbols = symbols.filter(s => matchPrefix(prefix, s.owner))
if filteredSymbols.size == 0 then symbols
else filteredSymbols
else symbols
case LocalMethod(name, _) =>
for
owner <- withCompanionIfExtendsAnyVal(cls)
term <- collectLocalSymbols(owner) {
case t: TermSymbol if t.matchName(name) && matchSignature(method, t) => t
}
yield term
case LazyInit(name) =>
cls.declarations.collect { case t: TermSymbol if t.isLazyVal && t.matchName(name) => t }
case _ =>
cls.declarations
.collect { case sym: TermSymbol => sym }
.filter(matchSymbol(method, _))
candidates.singleOptOrThrow(method.name)
cls match
case term: TermSymbol =>
if method.declaringClass.superclass.get.name == "scala.runtime.AbstractPartialFunction" then
Option.when(!method.isBridge)(term)
else Option.when(!method.isBridge && matchSignature(method, term))(term)
case cls: ClassSymbol =>
val candidates = method match
case LocalLazyInit(name, _) =>
for
owner <- withCompanionIfExtendsAnyVal(cls)
term <- collectLocalSymbols(owner) {
case (t: TermSymbol, None) if (t.isLazyVal || t.isModuleVal) && t.matchName(name) => t
}
yield term
case AnonFun(prefix) =>
val x =
for
owner <- withCompanionIfExtendsAnyVal(cls)
term <- collectLocalSymbols(owner) {
case (t: TermSymbol, None) if t.isAnonFun && matchSignature(method, t) => t
}
yield term
if x.size > 1 && prefix.nonEmpty then
val y = x.filter(s => matchPrefix(prefix, s.owner))
if y.size == 0 then x
else y
else x
case LocalMethod(name, _) =>
for
owner <- withCompanionIfExtendsAnyVal(cls)
term <- collectLocalSymbols(owner) {
case (t: TermSymbol, None) if t.matchName(name) && matchSignature(method, t) => t
}
yield term
case LazyInit(name) =>
cls.declarations.collect { case t: TermSymbol if t.isLazyVal && t.matchName(name) => t }
case _ =>
cls.declarations
.collect { case sym: TermSymbol => sym }
.filter(matchSymbol(method, _))
candidates.singleOptOrThrow(method.name)
case _ => None

def matchPrefix(prefix: String, owner: Symbol): Boolean =
if prefix.isEmpty then true
Expand All @@ -104,7 +111,7 @@ class Scala3Unpickler(
matchPrefix(stripped, owner)
else if prefix.endsWith("$init$") then owner.isTerm && !owner.asTerm.isMethod
else
val regex = owner.name.toString match
val regex = owner.nameStr match
case "$anonfun" => "\\$anonfun\\$\\d+$"
case name =>
Regex.quote(name)
Expand All @@ -127,14 +134,19 @@ class Scala3Unpickler(
Seq(cls, companionClass)
case _ => Seq(cls)

def collectLocalSymbols[S <: Symbol](cls: ClassSymbol)(partialF: PartialFunction[Symbol, S]): Seq[S] =
def collectLocalSymbols[S <: Symbol](cls: ClassSymbol)(
partialF: PartialFunction[(Symbol, Option[ClassSymbol]), S]
): Seq[S] =
val f = partialF.lift.andThen(_.toSeq)

def collectSymbols(tree: Tree): Seq[S] =
tree.walkTree {
case ValDef(_, _, _, symbol) if symbol.isLocal && (symbol.isLazyVal || symbol.isModuleVal) => f(symbol)
case DefDef(_, _, _, _, symbol) if symbol.isLocal => f(symbol)
case ClassDef(_, _, symbol) if symbol.isLocal => f(symbol)
case ValDef(_, _, _, symbol) if symbol.isLocal && (symbol.isLazyVal || symbol.isModuleVal) => f((symbol, None))
case DefDef(_, _, _, _, symbol) if symbol.isLocal => f(symbol, None)
case ClassDef(_, _, symbol) if symbol.isLocal => f(symbol, None)
case lambda: Lambda =>
val sym = lambda.meth.asInstanceOf[TermReferenceTree].symbol
f(sym, Some(lambda.samClassSymbol))
case _ => Seq.empty
}(_ ++ _, Seq.empty)

Expand Down Expand Up @@ -257,7 +269,7 @@ class Scala3Unpickler(
case owner: PackageSymbol => ""
val symName = sym.name match
case DefaultGetterName(termName, num) => s"${termName.toString()}.<default ${num + 1}>"
case _ => sym.name.toString()
case _ => sym.nameStr

if prefix.isEmpty then symName else s"$prefix.$symName"

Expand Down Expand Up @@ -297,7 +309,7 @@ class Scala3Unpickler(
case p: PackageRef => p.fullyQualifiedName.toString == "scala"
case _ => false

def findClass(cls: binary.ClassType, isExtensionMethod: Boolean = false): ClassSymbol =
def findClass(cls: binary.ClassType, isExtensionMethod: Boolean = false): Symbol =
val javaParts = cls.name.split('.')
val packageNames = javaParts.dropRight(1).toList.map(SimpleName.apply)
val packageSym =
Expand All @@ -306,38 +318,74 @@ class Scala3Unpickler(
else ctx.defn.EmptyPackage
val decodedClassName = NameTransformer.decode(javaParts.last)
val allSymbols = decodedClassName match
case AnonClass(declaringClassName, remaining) =>
val WithLocalPart = "(.+)\\$(.+)\\$\\d+".r
val decl = declaringClassName match
case WithLocalPart(decl, _) => decl.stripSuffix("$")
case decl => decl
findLocalClasses(cls, packageSym, decl, "$anon", remaining)
case LocalClass(declaringClassName, localClassName, remaining) =>
val owners = findSymbolsRecursively(packageSym, declaringClassName)
val localClasses = owners.flatMap(findLocalClasses(_, localClassName, cls))
remaining match
case None => localClasses
case Some(remaining) => localClasses.flatMap(findSymbolsRecursively(_, remaining))
findLocalClasses(cls, packageSym, declaringClassName, localClassName, remaining)
case _ => findSymbolsRecursively(packageSym, decodedClassName)
if cls.isObject && !isExtensionMethod
then allSymbols.filter(_.isModuleClass).singleOrThrow(cls.name)
else allSymbols.filter(!_.isModuleClass).singleOrThrow(cls.name)

private def findLocalClasses(
cls: binary.ClassType,
packageSym: PackageSymbol,
declaringClassName: String,
localClassName: String,
remaining: Option[String]
): Seq[Symbol] =
val owners = findSymbolsRecursively(packageSym, declaringClassName)
remaining match
case None => owners.flatMap(findLocalClasses(_, localClassName, Some(cls)))
case Some(remaining) =>
val localClasses = owners
.flatMap(findLocalClasses(_, localClassName, None))
.filter(_.isClass)
localClasses.flatMap(s => findSymbolsRecursively(s.asClass, remaining))

private def findSymbolsRecursively(owner: DeclaringSymbol, decodedName: String): Seq[ClassSymbol] =
owner.declarations
.collect { case sym: ClassSymbol => sym }
.flatMap { sym =>
val Symbol = s"${Regex.quote(sym.name.toString)}\\$$?(.*)".r
val Symbol = s"${Regex.quote(sym.nameStr)}\\$$?(.*)".r
decodedName match
case Symbol(remaining) =>
if remaining.isEmpty then Some(sym)
else findSymbolsRecursively(sym, remaining)
case _ => None
}

private def findLocalClasses(owner: ClassSymbol, name: String, cls: binary.ClassType): Seq[ClassSymbol] =
val superClassAndInterfaces = (cls.superclass.toSeq ++ cls.interfaces).map(findClass(_)).toSet

def matchesParents(classSymbol: ClassSymbol): Boolean =
if classSymbol.isEnum then superClassAndInterfaces == classSymbol.parentClasses.toSet + ctx.defn.ProductClass
else if cls.isInterface then superClassAndInterfaces == classSymbol.parentClasses.filter(_.isTrait).toSet
else superClassAndInterfaces == classSymbol.parentClasses.toSet

collectLocalSymbols(owner) { case cls: ClassSymbol if cls.matchName(name) && matchesParents(cls) => cls }
private def findLocalClasses(owner: ClassSymbol, name: String, javaClass: Option[binary.ClassType]): Seq[Symbol] =
javaClass match
case Some(cls) =>
val superClassAndInterfaces = (cls.superclass.toSeq ++ cls.interfaces).map(findClass(_)).toSet

def matchParents(classSymbol: ClassSymbol): Boolean =
if classSymbol.isEnum then superClassAndInterfaces == classSymbol.parentClasses.toSet + ctx.defn.ProductClass
else if cls.isInterface then superClassAndInterfaces == classSymbol.parentClasses.filter(_.isTrait).toSet
else if classSymbol.isAnonClass then classSymbol.parentClasses.forall(superClassAndInterfaces.contains)
else superClassAndInterfaces == classSymbol.parentClasses.toSet

def matchSamClass(samClass: ClassSymbol): Boolean =
if samClass == defn.partialFunction then
superClassAndInterfaces.size == 2 &&
superClassAndInterfaces.exists(_ == defn.abstractPartialFunction) &&
superClassAndInterfaces.exists(_ == defn.serializable)
else superClassAndInterfaces.contains(samClass)

collectLocalSymbols(owner) {
case (cls: ClassSymbol, None) if cls.matchName(name) && matchParents(cls) => cls
case (lambda: TermSymbol, Some(tpt)) if matchSamClass(tpt) => lambda
}
case _ =>
collectLocalSymbols(owner) {
case (cls: ClassSymbol, None) if cls.matchName(name) => cls
case (lambda: TermSymbol, Some(tpt)) => lambda
}

private def matchSymbol(method: binary.Method, symbol: TermSymbol): Boolean =
matchTargetName(method, symbol) && (method.isTraitInitializer || matchSignature(method, symbol))
Expand Down Expand Up @@ -446,27 +494,5 @@ class Scala3Unpickler(
private def skip(symbol: TermSymbol): Boolean =
(symbol.isGetterOrSetter && !symbol.isLazyValInTrait) || symbol.isSynthetic

extension (symbol: TermSymbol)
private def isGetterOrSetter = !symbol.isMethod || symbol.isSetter
private def isLazyValInTrait: Boolean = symbol.owner.isTrait && symbol.isLazyVal
private def isLazyVal: Boolean = symbol.kind == TermSymbolKind.LazyVal

extension (symbol: Symbol)
private def isTrait = symbol.isClass && symbol.asClass.isTrait
private def isAnonFun = symbol.name.toString() == "$anonfun"
private def matchName(name: String) =
symbol.name.toString == name
private def isLocal = symbol.owner.isTerm
private def isModuleClass = symbol.isClass && symbol.asClass.isModuleClass

extension [T <: Symbol](symbols: Seq[T])
def singleOrThrow(binaryName: String): T =
singleOptOrThrow(binaryName)
.getOrElse(throw new NotFoundException(s"Cannot find Scala symbol of $binaryName"))

def singleOptOrThrow(binaryName: String): Option[T] =
if symbols.size > 1 then throw new AmbiguousException(s"Found ${symbols.size} matching symbols for $binaryName")
else symbols.headOption

case class AmbiguousException(m: String) extends Exception
case class NotFoundException(m: String) extends Exception
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package ch.epfl.scala.debugadapter.internal.stacktrace

import tastyquery.Symbols.*
import tastyquery.Modifiers.TermSymbolKind

extension (symbol: Symbol)
def isTrait = symbol.isClass && symbol.asClass.isTrait
def isAnonFun = symbol.nameStr == "$anonfun"
def isAnonClass = symbol.nameStr == "$anon"
def matchName(name: String) =
symbol.nameStr == name
def isLocal = symbol.owner.isTerm
def isModuleClass = symbol.isClass && symbol.asClass.isModuleClass
def nameStr = symbol.name.toString

extension (symbol: TermSymbol)
private def isGetterOrSetter = !symbol.isMethod || symbol.isSetter
private def isLazyValInTrait: Boolean = symbol.owner.isTrait && symbol.isLazyVal
private def isLazyVal: Boolean = symbol.kind == TermSymbolKind.LazyVal

extension [T <: Symbol](symbols: Seq[T])
def singleOrThrow(binaryName: String): T =
singleOptOrThrow(binaryName)
.getOrElse(throw new NotFoundException(s"Cannot find Scala symbol of $binaryName"))

def singleOptOrThrow(binaryName: String): Option[T] =
if symbols.size > 1 then throw new AmbiguousException(s"Found ${symbols.size} matching symbols for $binaryName")
else symbols.headOption
Loading

0 comments on commit 28cc8ca

Please sign in to comment.