Skip to content

Commit

Permalink
Merge pull request #709 from adpi2/decoder-perf
Browse files Browse the repository at this point in the history
[Scala 3 binary decoder] Cache decoded symbols
  • Loading branch information
adpi2 committed May 8, 2024
2 parents 2862057 + b5e6c95 commit e8a4b8d
Show file tree
Hide file tree
Showing 5 changed files with 57 additions and 29 deletions.
Expand Up @@ -60,7 +60,9 @@ object StackTraceProvider {
if (config.stepFilters.skipClassLoading) list = ClassLoadingFilter +: list
if (config.stepFilters.skipRuntimeClasses) list = RuntimeStepFilter(debuggee.scalaVersion) +: list
if (config.stepFilters.skipForwardersAndAccessors)
list = ScalaDecoder(debuggee, tools, logger, config.testMode) +: list
list = TimeUtils.logTime(logger, "Initialized Scala 3 decoder") {
ScalaDecoder(debuggee, tools, logger, config.testMode)
} +: list
list
}

Expand Down
Expand Up @@ -21,7 +21,12 @@ object BinaryDecoder:
val ctx = Context.initialize(classpath)
new BinaryDecoder(using ctx)

final class BinaryDecoder(using Context, ThrowOrWarn):
def cached(classEntries: Seq[Path])(using ThrowOrWarn): BinaryDecoder =
val classpath = CustomClasspath(ClasspathLoaders.read(classEntries.toList))
val ctx = Context.initialize(classpath)
new CachedBinaryDecoder(using ctx)

class BinaryDecoder(using Context, ThrowOrWarn):
private given defn: Definitions = Definitions()

def decode(cls: binary.ClassType): DecodedClass =
Expand Down Expand Up @@ -799,7 +804,16 @@ final class BinaryDecoder(using Context, ThrowOrWarn):
private def collectLiftedTrees[S](owner: Symbol, sourceLines: Option[binary.SourceLines])(
matcher: PartialFunction[LiftedTree[?], LiftedTree[S]]
): Seq[LiftedTree[S]] =
LiftedTreeCollector.collect(owner)(matcher).filter(tree => sourceLines.forall(matchLines(tree, _)))
val recursiveMatcher = new PartialFunction[LiftedTree[?], LiftedTree[S]]:
override def apply(tree: LiftedTree[?]): LiftedTree[S] = tree.asInstanceOf[LiftedTree[S]]
override def isDefinedAt(tree: LiftedTree[?]): Boolean = tree match
case InlinedFromArg(underlying, _, _) => isDefinedAt(underlying)
case InlinedFromDef(underlying, _) => isDefinedAt(underlying)
case _ => matcher.isDefinedAt(tree)
collectAllLiftedTrees(owner).collect(recursiveMatcher).filter(tree => sourceLines.forall(matchLines(tree, _)))

protected def collectAllLiftedTrees(owner: Symbol): Seq[LiftedTree[?]] =
LiftedTreeCollector.collect(owner)

private def matchLines(liftedFun: LiftedTree[?], sourceLines: binary.SourceLines): Boolean =
// we use endsWith instead of == because of tasty-query#434
Expand Down
@@ -0,0 +1,23 @@
package ch.epfl.scala.debugadapter.internal.stacktrace

import ch.epfl.scala.debugadapter.internal.binary
import ch.epfl.scala.debugadapter.internal.binary.Method
import ch.epfl.scala.debugadapter.internal.binary.SignedName
import tastyquery.Contexts.*
import tastyquery.Symbols.*

import scala.collection.concurrent.TrieMap

class CachedBinaryDecoder(using Context, ThrowOrWarn) extends BinaryDecoder:
private val classCache: TrieMap[String, DecodedClass] = TrieMap.empty
private val methodCache: TrieMap[(String, SignedName), DecodedMethod] = TrieMap.empty
private val liftedTreesCache: TrieMap[Symbol, Seq[LiftedTree[?]]] = TrieMap.empty

override def decode(cls: binary.ClassType): DecodedClass =
classCache.getOrElseUpdate(cls.name, super.decode(cls))

override def decode(method: Method): DecodedMethod =
methodCache.getOrElseUpdate((method.declaringClass.name, method.signedName), super.decode(method))

override protected def collectAllLiftedTrees(owner: Symbol): Seq[LiftedTree[?]] =
liftedTreesCache.getOrElseUpdate(owner, super.collectAllLiftedTrees(owner))
Expand Up @@ -17,46 +17,37 @@ import tastyquery.Exceptions.NonMethodReferenceException
* and compute the capture.
*/
object LiftedTreeCollector:
def collect[S](sym: Symbol)(matcher: PartialFunction[LiftedTree[?], LiftedTree[S]])(using
Context,
Definitions,
ThrowOrWarn
): Seq[LiftedTree[S]] =
val collector = LiftedTreeCollector[S](sym, matcher)
def collect(sym: Symbol)(using Context, Definitions, ThrowOrWarn): Seq[LiftedTree[?]] =
val collector = LiftedTreeCollector(sym)
sym.tree.toSeq.flatMap(collector.collect)

class LiftedTreeCollector[S] private (root: Symbol, matcher: PartialFunction[LiftedTree[?], LiftedTree[S]])(using
Context,
Definitions,
ThrowOrWarn
):
private val inlinedTrees = mutable.Map.empty[TermSymbol, Seq[LiftedTree[S]]]
class LiftedTreeCollector private (root: Symbol)(using Context, Definitions, ThrowOrWarn):
private val inlinedTrees = mutable.Map.empty[TermSymbol, Seq[LiftedTree[?]]]
private var owner = root

def collect(tree: Tree): Seq[LiftedTree[S]] =
val buffer = mutable.Buffer.empty[LiftedTree[S]]
def collect(tree: Tree): Seq[LiftedTree[?]] =
val buffer = mutable.Buffer.empty[LiftedTree[?]]

object Traverser extends TreeTraverser:
override def traverse(tree: Tree): Unit =
// register lifted funs
tree match
case tree: DefDef if tree.symbol.isLocal => registerLiftedFun(LocalDef(tree))
case tree: DefDef if tree.symbol.isLocal => buffer += LocalDef(tree)
case tree: ValDef if tree.symbol.isLocal && tree.symbol.isModuleOrLazyVal =>
registerLiftedFun(LocalLazyVal(tree))
case tree: ClassDef if tree.symbol.isLocal => registerLiftedFun(LocalClass(tree))
case tree: Lambda => registerLiftedFun(LambdaTree(tree))
case tree: Try => registerLiftedFun(LiftedTry(owner, tree))
buffer += LocalLazyVal(tree)
case tree: ClassDef if tree.symbol.isLocal => buffer += LocalClass(tree)
case tree: Lambda => buffer += LambdaTree(tree)
case tree: Try => buffer += LiftedTry(owner, tree)
case tree: Apply =>
for
symbol <- tree.safeFunSymbol
methodType <- tree.safeMethodType
do
val paramTypesAndArgs = methodType.paramTypes.zip(tree.args)
for case (byNameTpe: ByNameType, arg) <- paramTypesAndArgs do
registerLiftedFun(ByNameArg(owner, arg, byNameTpe.resultType, symbol.isInline))
buffer += ByNameArg(owner, arg, byNameTpe.resultType, symbol.isInline)
if owner.isClass && symbol.isConstructor then
for (paramTpe, arg) <- paramTypesAndArgs do
registerLiftedFun(ConstructorArg(owner.asClass, arg, paramTpe))
for (paramTpe, arg) <- paramTypesAndArgs do buffer += ConstructorArg(owner.asClass, arg, paramTpe)
case _ => ()

// recurse
Expand All @@ -82,15 +73,13 @@ class LiftedTreeCollector[S] private (root: Symbol, matcher: PartialFunction[Lif
case tree: (StatementTree | Template | CaseDef) => super.traverse(tree)
case _ => ()

def registerLiftedFun(tree: LiftedTree[?]): Unit =
matcher.lift(tree).foreach(buffer += _)
end Traverser

Traverser.traverse(tree)
buffer.toSeq
end collect

private def collectInlineDef(symbol: TermSymbol): Seq[LiftedTree[S]] =
private def collectInlineDef(symbol: TermSymbol): Seq[LiftedTree[?]] =
inlinedTrees(symbol) = Seq.empty // break recursion
symbol.tree.flatMap(extractRHS).toSeq.flatMap(collect)

Expand Down
Expand Up @@ -23,7 +23,7 @@ class Scala3DecoderBridge(
warnLogger: Consumer[String],
testMode: Boolean
):
private val decoder: BinaryDecoder = BinaryDecoder(classEntries)(
private val decoder: BinaryDecoder = BinaryDecoder.cached(classEntries)(
// make it quiet, or it would be too verbose when things go wrong
using ThrowOrWarn(_ => (), testMode)
)
Expand Down

0 comments on commit e8a4b8d

Please sign in to comment.