diff --git a/modules/core/src/main/scala/ch/epfl/scala/debugadapter/internal/StackTraceProvider.scala b/modules/core/src/main/scala/ch/epfl/scala/debugadapter/internal/StackTraceProvider.scala index 609ea559b..edcb429a3 100644 --- a/modules/core/src/main/scala/ch/epfl/scala/debugadapter/internal/StackTraceProvider.scala +++ b/modules/core/src/main/scala/ch/epfl/scala/debugadapter/internal/StackTraceProvider.scala @@ -60,7 +60,9 @@ object StackTraceProvider { if (config.stepFilters.skipClassLoading) list = ClassLoadingFilter +: list if (config.stepFilters.skipRuntimeClasses) list = RuntimeStepFilter(debuggee.scalaVersion) +: list if (config.stepFilters.skipForwardersAndAccessors) - list = ScalaDecoder(debuggee, tools, logger, config.testMode) +: list + list = TimeUtils.logTime(logger, "Initialized Scala 3 decoder") { + ScalaDecoder(debuggee, tools, logger, config.testMode) + } +: list list } diff --git a/modules/decoder/src/main/scala/ch/epfl/scala/debugadapter/internal/stacktrace/BinaryDecoder.scala b/modules/decoder/src/main/scala/ch/epfl/scala/debugadapter/internal/stacktrace/BinaryDecoder.scala index 61fac3082..d8e6d03ba 100644 --- a/modules/decoder/src/main/scala/ch/epfl/scala/debugadapter/internal/stacktrace/BinaryDecoder.scala +++ b/modules/decoder/src/main/scala/ch/epfl/scala/debugadapter/internal/stacktrace/BinaryDecoder.scala @@ -21,7 +21,12 @@ object BinaryDecoder: val ctx = Context.initialize(classpath) new BinaryDecoder(using ctx) -final class BinaryDecoder(using Context, ThrowOrWarn): + def cached(classEntries: Seq[Path])(using ThrowOrWarn): BinaryDecoder = + val classpath = CustomClasspath(ClasspathLoaders.read(classEntries.toList)) + val ctx = Context.initialize(classpath) + new CachedBinaryDecoder(using ctx) + +class BinaryDecoder(using Context, ThrowOrWarn): private given defn: Definitions = Definitions() def decode(cls: binary.ClassType): DecodedClass = @@ -799,7 +804,16 @@ final class BinaryDecoder(using Context, ThrowOrWarn): private def collectLiftedTrees[S](owner: Symbol, sourceLines: Option[binary.SourceLines])( matcher: PartialFunction[LiftedTree[?], LiftedTree[S]] ): Seq[LiftedTree[S]] = - LiftedTreeCollector.collect(owner)(matcher).filter(tree => sourceLines.forall(matchLines(tree, _))) + val recursiveMatcher = new PartialFunction[LiftedTree[?], LiftedTree[S]]: + override def apply(tree: LiftedTree[?]): LiftedTree[S] = tree.asInstanceOf[LiftedTree[S]] + override def isDefinedAt(tree: LiftedTree[?]): Boolean = tree match + case InlinedFromArg(underlying, _, _) => isDefinedAt(underlying) + case InlinedFromDef(underlying, _) => isDefinedAt(underlying) + case _ => matcher.isDefinedAt(tree) + collectAllLiftedTrees(owner).collect(recursiveMatcher).filter(tree => sourceLines.forall(matchLines(tree, _))) + + protected def collectAllLiftedTrees(owner: Symbol): Seq[LiftedTree[?]] = + LiftedTreeCollector.collect(owner) private def matchLines(liftedFun: LiftedTree[?], sourceLines: binary.SourceLines): Boolean = // we use endsWith instead of == because of tasty-query#434 diff --git a/modules/decoder/src/main/scala/ch/epfl/scala/debugadapter/internal/stacktrace/CachedBinaryDecoder.scala b/modules/decoder/src/main/scala/ch/epfl/scala/debugadapter/internal/stacktrace/CachedBinaryDecoder.scala new file mode 100644 index 000000000..2f6e08e94 --- /dev/null +++ b/modules/decoder/src/main/scala/ch/epfl/scala/debugadapter/internal/stacktrace/CachedBinaryDecoder.scala @@ -0,0 +1,23 @@ +package ch.epfl.scala.debugadapter.internal.stacktrace + +import ch.epfl.scala.debugadapter.internal.binary +import ch.epfl.scala.debugadapter.internal.binary.Method +import ch.epfl.scala.debugadapter.internal.binary.SignedName +import tastyquery.Contexts.* +import tastyquery.Symbols.* + +import scala.collection.concurrent.TrieMap + +class CachedBinaryDecoder(using Context, ThrowOrWarn) extends BinaryDecoder: + private val classCache: TrieMap[String, DecodedClass] = TrieMap.empty + private val methodCache: TrieMap[(String, SignedName), DecodedMethod] = TrieMap.empty + private val liftedTreesCache: TrieMap[Symbol, Seq[LiftedTree[?]]] = TrieMap.empty + + override def decode(cls: binary.ClassType): DecodedClass = + classCache.getOrElseUpdate(cls.name, super.decode(cls)) + + override def decode(method: Method): DecodedMethod = + methodCache.getOrElseUpdate((method.declaringClass.name, method.signedName), super.decode(method)) + + override protected def collectAllLiftedTrees(owner: Symbol): Seq[LiftedTree[?]] = + liftedTreesCache.getOrElseUpdate(owner, super.collectAllLiftedTrees(owner)) diff --git a/modules/decoder/src/main/scala/ch/epfl/scala/debugadapter/internal/stacktrace/LiftedTreeCollector.scala b/modules/decoder/src/main/scala/ch/epfl/scala/debugadapter/internal/stacktrace/LiftedTreeCollector.scala index f85dc7271..a400e75a3 100644 --- a/modules/decoder/src/main/scala/ch/epfl/scala/debugadapter/internal/stacktrace/LiftedTreeCollector.scala +++ b/modules/decoder/src/main/scala/ch/epfl/scala/debugadapter/internal/stacktrace/LiftedTreeCollector.scala @@ -17,35 +17,27 @@ import tastyquery.Exceptions.NonMethodReferenceException * and compute the capture. */ object LiftedTreeCollector: - def collect[S](sym: Symbol)(matcher: PartialFunction[LiftedTree[?], LiftedTree[S]])(using - Context, - Definitions, - ThrowOrWarn - ): Seq[LiftedTree[S]] = - val collector = LiftedTreeCollector[S](sym, matcher) + def collect(sym: Symbol)(using Context, Definitions, ThrowOrWarn): Seq[LiftedTree[?]] = + val collector = LiftedTreeCollector(sym) sym.tree.toSeq.flatMap(collector.collect) -class LiftedTreeCollector[S] private (root: Symbol, matcher: PartialFunction[LiftedTree[?], LiftedTree[S]])(using - Context, - Definitions, - ThrowOrWarn -): - private val inlinedTrees = mutable.Map.empty[TermSymbol, Seq[LiftedTree[S]]] +class LiftedTreeCollector private (root: Symbol)(using Context, Definitions, ThrowOrWarn): + private val inlinedTrees = mutable.Map.empty[TermSymbol, Seq[LiftedTree[?]]] private var owner = root - def collect(tree: Tree): Seq[LiftedTree[S]] = - val buffer = mutable.Buffer.empty[LiftedTree[S]] + def collect(tree: Tree): Seq[LiftedTree[?]] = + val buffer = mutable.Buffer.empty[LiftedTree[?]] object Traverser extends TreeTraverser: override def traverse(tree: Tree): Unit = // register lifted funs tree match - case tree: DefDef if tree.symbol.isLocal => registerLiftedFun(LocalDef(tree)) + case tree: DefDef if tree.symbol.isLocal => buffer += LocalDef(tree) case tree: ValDef if tree.symbol.isLocal && tree.symbol.isModuleOrLazyVal => - registerLiftedFun(LocalLazyVal(tree)) - case tree: ClassDef if tree.symbol.isLocal => registerLiftedFun(LocalClass(tree)) - case tree: Lambda => registerLiftedFun(LambdaTree(tree)) - case tree: Try => registerLiftedFun(LiftedTry(owner, tree)) + buffer += LocalLazyVal(tree) + case tree: ClassDef if tree.symbol.isLocal => buffer += LocalClass(tree) + case tree: Lambda => buffer += LambdaTree(tree) + case tree: Try => buffer += LiftedTry(owner, tree) case tree: Apply => for symbol <- tree.safeFunSymbol @@ -53,10 +45,9 @@ class LiftedTreeCollector[S] private (root: Symbol, matcher: PartialFunction[Lif do val paramTypesAndArgs = methodType.paramTypes.zip(tree.args) for case (byNameTpe: ByNameType, arg) <- paramTypesAndArgs do - registerLiftedFun(ByNameArg(owner, arg, byNameTpe.resultType, symbol.isInline)) + buffer += ByNameArg(owner, arg, byNameTpe.resultType, symbol.isInline) if owner.isClass && symbol.isConstructor then - for (paramTpe, arg) <- paramTypesAndArgs do - registerLiftedFun(ConstructorArg(owner.asClass, arg, paramTpe)) + for (paramTpe, arg) <- paramTypesAndArgs do buffer += ConstructorArg(owner.asClass, arg, paramTpe) case _ => () // recurse @@ -82,15 +73,13 @@ class LiftedTreeCollector[S] private (root: Symbol, matcher: PartialFunction[Lif case tree: (StatementTree | Template | CaseDef) => super.traverse(tree) case _ => () - def registerLiftedFun(tree: LiftedTree[?]): Unit = - matcher.lift(tree).foreach(buffer += _) end Traverser Traverser.traverse(tree) buffer.toSeq end collect - private def collectInlineDef(symbol: TermSymbol): Seq[LiftedTree[S]] = + private def collectInlineDef(symbol: TermSymbol): Seq[LiftedTree[?]] = inlinedTrees(symbol) = Seq.empty // break recursion symbol.tree.flatMap(extractRHS).toSeq.flatMap(collect) diff --git a/modules/decoder/src/main/scala/ch/epfl/scala/debugadapter/internal/stacktrace/Scala3DecoderBridge.scala b/modules/decoder/src/main/scala/ch/epfl/scala/debugadapter/internal/stacktrace/Scala3DecoderBridge.scala index 0e174fbad..bbc3601a9 100644 --- a/modules/decoder/src/main/scala/ch/epfl/scala/debugadapter/internal/stacktrace/Scala3DecoderBridge.scala +++ b/modules/decoder/src/main/scala/ch/epfl/scala/debugadapter/internal/stacktrace/Scala3DecoderBridge.scala @@ -23,7 +23,7 @@ class Scala3DecoderBridge( warnLogger: Consumer[String], testMode: Boolean ): - private val decoder: BinaryDecoder = BinaryDecoder(classEntries)( + private val decoder: BinaryDecoder = BinaryDecoder.cached(classEntries)( // make it quiet, or it would be too verbose when things go wrong using ThrowOrWarn(_ => (), testMode) )