From 85cd09c325605e0dd0ab08ad59c09bc693e72de8 Mon Sep 17 00:00:00 2001 From: Tomasz Godzik Date: Wed, 1 Feb 2023 18:42:35 +0100 Subject: [PATCH] feature: Add support for running Native Works by doing the bare minimum DAP requires and forwars the logs from the build server --- .../metals/BuildServerConnection.scala | 11 ++ .../metals/ForwardingMetalsBuildClient.scala | 66 +++++-- .../internal/metals/MetalsLspService.scala | 1 + .../metals/codelenses/RunTestCodeLens.scala | 64 ++++--- .../internal/metals/debug/DebugProtocol.scala | 42 +++- .../internal/metals/debug/DebugProvider.scala | 108 +++++++++-- .../internal/metals/debug/DebugProxy.scala | 4 +- .../internal/metals/debug/DebugRunner.scala | 179 ++++++++++++++++++ .../scala/tests/scalacli/ScalaCliSuite.scala | 119 ++++++++++++ .../internal/metals/debug/RemoteServer.scala | 6 +- .../scala/tests/BaseCodeLensLspSuite.scala | 20 +- .../src/main/scala/tests/TestingClient.scala | 2 + .../src/main/scala/tests/TestingServer.scala | 54 ++++-- .../test/scala/tests/CodeLensLspSuite.scala | 1 + 14 files changed, 588 insertions(+), 89 deletions(-) create mode 100644 metals/src/main/scala/scala/meta/internal/metals/debug/DebugRunner.scala diff --git a/metals/src/main/scala/scala/meta/internal/metals/BuildServerConnection.scala b/metals/src/main/scala/scala/meta/internal/metals/BuildServerConnection.scala index fcc150eadca..7a65f500421 100644 --- a/metals/src/main/scala/scala/meta/internal/metals/BuildServerConnection.scala +++ b/metals/src/main/scala/scala/meta/internal/metals/BuildServerConnection.scala @@ -294,6 +294,17 @@ class BuildServerConnection private ( } } + def buildTargetRun( + params: RunParams, + cancelPromise: Promise[Unit], + ): Future[RunResult] = { + val completableFuture = register(server => server.buildTargetRun(params)) + cancelPromise.future.foreach { _ => + completableFuture.cancel(true) + } + completableFuture.asScala + } + def buildTargetScalacOptions( params: ScalacOptionsParams ): Future[ScalacOptionsResult] = { diff --git a/metals/src/main/scala/scala/meta/internal/metals/ForwardingMetalsBuildClient.scala b/metals/src/main/scala/scala/meta/internal/metals/ForwardingMetalsBuildClient.scala index 59a9da7cf27..cb465de59b7 100644 --- a/metals/src/main/scala/scala/meta/internal/metals/ForwardingMetalsBuildClient.scala +++ b/metals/src/main/scala/scala/meta/internal/metals/ForwardingMetalsBuildClient.scala @@ -2,7 +2,7 @@ package scala.meta.internal.metals.clients.language import java.util.Collections import java.util.concurrent.ConcurrentHashMap -import java.{util => ju} +import java.util.concurrent.atomic.AtomicReference import scala.collection.concurrent.TrieMap import scala.concurrent.Promise @@ -23,12 +23,24 @@ import scala.meta.internal.metals.Timer import scala.meta.internal.tvp._ import scala.meta.io.AbsolutePath -import ch.epfl.scala.bsp4j._ import ch.epfl.scala.{bsp4j => b} import com.google.gson.JsonObject import org.eclipse.lsp4j.jsonrpc.services.JsonNotification import org.eclipse.{lsp4j => l} +/** + * Used to forward messages from the build server. Messages might + * be mixed if the server is sending messages as well as output from + * running. This hasn't been a problem yet, not perfect solution, + * but seems to work ok. + */ +trait LogForwarder { + def error(message: String): Unit = () + def warn(message: String): Unit = () + def info(message: String): Unit = () + def log(message: String): Unit = () +} + /** * A build client that forwards notifications from the build server to the language client. */ @@ -39,31 +51,39 @@ final class ForwardingMetalsBuildClient( clientConfig: ClientConfiguration, statusBar: StatusBar, time: Time, - didCompile: CompileReport => Unit, - onBuildTargetDidCompile: BuildTargetIdentifier => Unit, + didCompile: b.CompileReport => Unit, + onBuildTargetDidCompile: b.BuildTargetIdentifier => Unit, onBuildTargetDidChangeFunc: b.DidChangeBuildTarget => Unit, bspErrorHandler: BspErrorHandler, ) extends MetalsBuildClient with Cancelable { + private val forwarders = + new AtomicReference(List.empty[LogForwarder]) + + def registerLogForwarder( + logForwarder: LogForwarder + ): List[LogForwarder] = { + forwarders.getAndUpdate(_.prepended(logForwarder)) + } private case class Compilation( timer: Timer, - promise: Promise[CompileReport], + promise: Promise[b.CompileReport], isNoOp: Boolean, progress: TaskProgress = TaskProgress.empty, ) extends TreeViewCompilation { def progressPercentage = progress.percentage } - private val compilations = TrieMap.empty[BuildTargetIdentifier, Compilation] + private val compilations = TrieMap.empty[b.BuildTargetIdentifier, Compilation] private val hasReportedError = Collections.newSetFromMap( - new ConcurrentHashMap[BuildTargetIdentifier, java.lang.Boolean]() + new ConcurrentHashMap[b.BuildTargetIdentifier, java.lang.Boolean]() ) - val updatedTreeViews: ju.Set[BuildTargetIdentifier] = - ConcurrentHashSet.empty[BuildTargetIdentifier] + val updatedTreeViews: java.util.Set[b.BuildTargetIdentifier] = + ConcurrentHashSet.empty[b.BuildTargetIdentifier] - def buildHasErrors(buildTargetId: BuildTargetIdentifier): Boolean = { + def buildHasErrors(buildTargetId: b.BuildTargetIdentifier): Boolean = { buildTargets .buildTargetTransitiveDependencies(buildTargetId) .exists(hasReportedError.contains(_)) @@ -96,18 +116,22 @@ final class ForwardingMetalsBuildClient( def onBuildShowMessage(params: l.MessageParams): Unit = languageClient.showMessage(params) - def onBuildLogMessage(params: l.MessageParams): Unit = + def onBuildLogMessage(params: l.MessageParams): Unit = { params.getType match { case l.MessageType.Error => bspErrorHandler.onError(params.getMessage()) + forwarders.get().foreach(_.error(params.getMessage())) case l.MessageType.Warning => + forwarders.get().foreach(_.warn(params.getMessage())) scribe.warn(params.getMessage) case l.MessageType.Info => + forwarders.get().foreach(_.info(params.getMessage())) scribe.info(params.getMessage) case l.MessageType.Log => + forwarders.get().foreach(_.log(params.getMessage())) scribe.info(params.getMessage) } - + } def onBuildPublishDiagnostics(params: b.PublishDiagnosticsParams): Unit = { diagnostics.onBuildPublishDiagnostics(params) } @@ -119,9 +143,9 @@ final class ForwardingMetalsBuildClient( def onBuildTargetCompileReport(params: b.CompileReport): Unit = {} @JsonNotification("build/taskStart") - def buildTaskStart(params: TaskStartParams): Unit = { + def buildTaskStart(params: b.TaskStartParams): Unit = { params.getDataKind match { - case TaskStartDataKind.COMPILE_TASK => + case b.TaskStartDataKind.COMPILE_TASK => if ( params.getMessage != null && params.getMessage.startsWith("Compiling") ) { @@ -137,7 +161,7 @@ final class ForwardingMetalsBuildClient( compilations.remove(target).foreach(_.promise.cancel()) val name = info.getDisplayName - val promise = Promise[CompileReport]() + val promise = Promise[b.CompileReport]() val isNoOp = params.getMessage != null && params.getMessage.startsWith( "Start no-op compilation" @@ -157,9 +181,9 @@ final class ForwardingMetalsBuildClient( } @JsonNotification("build/taskFinish") - def buildTaskFinish(params: TaskFinishParams): Unit = { + def buildTaskFinish(params: b.TaskFinishParams): Unit = { params.getDataKind match { - case TaskFinishDataKind.COMPILE_REPORT => + case b.TaskFinishDataKind.COMPILE_REPORT => for { report <- params.asCompileReport compilation <- compilations.remove(report.getTarget) @@ -214,8 +238,8 @@ final class ForwardingMetalsBuildClient( } @JsonNotification("build/taskProgress") - def buildTaskProgress(params: TaskProgressParams): Unit = { - def buildTargetFromParams: Option[BuildTargetIdentifier] = + def buildTaskProgress(params: b.TaskProgressParams): Unit = { + def buildTargetFromParams: Option[b.BuildTargetIdentifier] = for { data <- Option(params.getData).collect { case o: JsonObject => o @@ -227,7 +251,7 @@ final class ForwardingMetalsBuildClient( if uriElement.isJsonPrimitive uri = uriElement.getAsJsonPrimitive if uri.isString - } yield new BuildTargetIdentifier(uri.getAsString) + } yield new b.BuildTargetIdentifier(uri.getAsString) params.getDataKind match { case "bloop-progress" => @@ -252,7 +276,7 @@ final class ForwardingMetalsBuildClient( def ongoingCompilations(): TreeViewCompilations = new TreeViewCompilations { - override def get(id: BuildTargetIdentifier) = compilations.get(id) + override def get(id: b.BuildTargetIdentifier) = compilations.get(id) override def isEmpty = compilations.isEmpty override def size = compilations.size override def buildTargets = compilations.keysIterator diff --git a/metals/src/main/scala/scala/meta/internal/metals/MetalsLspService.scala b/metals/src/main/scala/scala/meta/internal/metals/MetalsLspService.scala index 4f8b8286fc5..5ffe97f96b9 100644 --- a/metals/src/main/scala/scala/meta/internal/metals/MetalsLspService.scala +++ b/metals/src/main/scala/scala/meta/internal/metals/MetalsLspService.scala @@ -728,6 +728,7 @@ class MetalsLspService( testProvider, ) ) + buildClient.registerLogForwarder(debugProvider) private val scalafixProvider: ScalafixProvider = ScalafixProvider( buffers, diff --git a/metals/src/main/scala/scala/meta/internal/metals/codelenses/RunTestCodeLens.scala b/metals/src/main/scala/scala/meta/internal/metals/codelenses/RunTestCodeLens.scala index 903e9087e2a..6ef6311dfd3 100644 --- a/metals/src/main/scala/scala/meta/internal/metals/codelenses/RunTestCodeLens.scala +++ b/metals/src/main/scala/scala/meta/internal/metals/codelenses/RunTestCodeLens.scala @@ -64,8 +64,7 @@ final class RunTestCodeLens( val lenses = for { buildTargetId <- buildTargets.inverseSources(path) buildTarget <- buildTargets.info(buildTargetId) - // generate code lenses only for JVM based targets for Scala - if buildTarget.asScalaBuildTarget.forall( + isJVM = buildTarget.asScalaBuildTarget.forall( _.getPlatform == b.ScalaPlatform.JVM ) connection <- buildTargets.buildServerOf(buildTargetId) @@ -85,6 +84,7 @@ final class RunTestCodeLens( classes, distance, buildServerCanDebug, + isJVM, ) } else if (buildServerCanDebug || clientConfig.isRunProvider()) { codeLenses( @@ -94,6 +94,7 @@ final class RunTestCodeLens( distance, path, buildServerCanDebug, + isJVM, ) } else { Nil } @@ -160,6 +161,7 @@ final class RunTestCodeLens( Nil.asJava, ), buildServerCanDebug, + isJVM = true, ) else Nil @@ -177,6 +179,7 @@ final class RunTestCodeLens( distance: TokenEditDistance, path: AbsolutePath, buildServerCanDebug: Boolean, + isJVM: Boolean, ): Seq[l.CodeLens] = { for { occurrence <- textDocument.occurrences @@ -185,19 +188,19 @@ final class RunTestCodeLens( commands = { val main = classes.mainClasses .get(symbol) - .map(mainCommand(target, _, buildServerCanDebug)) + .map(mainCommand(target, _, buildServerCanDebug, isJVM)) .getOrElse(Nil) val tests = // Currently tests can only be run via DAP if (clientConfig.isDebuggingProvider() && buildServerCanDebug) - testClasses(target, classes, symbol) + testClasses(target, classes, symbol, isJVM) else Nil val fromAnnot = DebugProvider .mainFromAnnotation(occurrence, textDocument) .flatMap { symbol => classes.mainClasses .get(symbol) - .map(mainCommand(target, _, buildServerCanDebug)) + .map(mainCommand(target, _, buildServerCanDebug, isJVM)) } .getOrElse(Nil) val javaMains = @@ -225,6 +228,7 @@ final class RunTestCodeLens( classes: BuildTargetClasses.Classes, distance: TokenEditDistance, buildServerCanDebug: Boolean, + isJVM: Boolean, ): Seq[l.CodeLens] = { val scriptFileName = textDocument.uri.stripSuffix(".sc") @@ -234,7 +238,7 @@ final class RunTestCodeLens( val main = classes.mainClasses .get(expectedMainClass) - .map(mainCommand(target, _, buildServerCanDebug)) + .map(mainCommand(target, _, buildServerCanDebug, isJVM)) .getOrElse(Nil) val fromAnnotations = textDocument.occurrences.flatMap { occ => @@ -242,7 +246,7 @@ final class RunTestCodeLens( sym <- DebugProvider.mainFromAnnotation(occ, textDocument) cls <- classes.mainClasses.get(sym) range <- occurrenceRange(occ, distance) - } yield mainCommand(target, cls, buildServerCanDebug).map { cmd => + } yield mainCommand(target, cls, buildServerCanDebug, isJVM).map { cmd => new l.CodeLens(range, cmd, null) } }.flatten @@ -262,13 +266,14 @@ final class RunTestCodeLens( target: BuildTargetIdentifier, classes: BuildTargetClasses.Classes, symbol: String, + isJVM: Boolean, ): List[l.Command] = if (userConfig().testUserInterface == TestUserInterfaceKind.CodeLenses) classes.testClasses .get(symbol) .toList .flatMap(symbolInfo => - testCommand(target, symbolInfo.fullyQualifiedName) + testCommand(target, symbolInfo.fullyQualifiedName, isJVM) ) else Nil @@ -276,6 +281,7 @@ final class RunTestCodeLens( private def testCommand( target: b.BuildTargetIdentifier, className: String, + isJVM: Boolean, ): List[l.Command] = { val params = { val dataKind = b.TestParamsDataKind.SCALA_TEST_SUITES @@ -283,16 +289,22 @@ final class RunTestCodeLens( sessionParams(target, dataKind, data) } - List( - command("test", StartRunSession, params), - command("debug test", StartDebugSession, params), - ) + if (isJVM) + List( + command("test", StartRunSession, params), + command("debug test", StartDebugSession, params), + ) + else + List( + command("test", StartRunSession, params) + ) } private def mainCommand( target: b.BuildTargetIdentifier, main: b.ScalaMainClass, buildServerCanDebug: Boolean, + isJVM: Boolean, ): List[l.Command] = { val javaBinary = buildTargets .scalaTarget(target) @@ -300,26 +312,32 @@ final class RunTestCodeLens( JavaBinary.javaBinaryFromPath(scalaTarget.jvmHome) ) .orElse(userConfig().usedJavaBinary) - val (data, shellCommandAdded) = buildTargetClasses.jvmRunEnvironment - .get(target) - .zip(javaBinary) match { - case None => - (main.toJson, false) - case Some((env, javaBinary)) => - (ExtendedScalaMainClass(main, env, javaBinary, workspace).toJson, true) - } + val (data, shellCommandAdded) = + if (!isJVM) (main.toJson, false) + else + buildTargetClasses.jvmRunEnvironment + .get(target) + .zip(javaBinary) match { + case None => + (main.toJson, false) + case Some((env, javaBinary)) => + ( + ExtendedScalaMainClass(main, env, javaBinary, workspace).toJson, + true, + ) + } val params = { val dataKind = b.DebugSessionParamsDataKind.SCALA_MAIN_CLASS sessionParams(target, dataKind, data) } - if (clientConfig.isDebuggingProvider() && buildServerCanDebug) + if (clientConfig.isDebuggingProvider() && buildServerCanDebug && isJVM) List( command("run", StartRunSession, params), command("debug", StartDebugSession, params), ) - // run provider needs shell command to run currently, we don't support pure run inside metals - else if (shellCommandAdded && clientConfig.isRunProvider()) + // run provider needs shell command to run currently, we don't support pure run inside metals for JVM + else if ((shellCommandAdded || !isJVM) && clientConfig.isRunProvider()) List(command("run", StartRunSession, params)) else Nil } diff --git a/metals/src/main/scala/scala/meta/internal/metals/debug/DebugProtocol.scala b/metals/src/main/scala/scala/meta/internal/metals/debug/DebugProtocol.scala index 6538725a44f..45d98dbf89c 100644 --- a/metals/src/main/scala/scala/meta/internal/metals/debug/DebugProtocol.scala +++ b/metals/src/main/scala/scala/meta/internal/metals/debug/DebugProtocol.scala @@ -97,6 +97,15 @@ object DebugProtocol { response } + object EmptyResponse { + def apply(initialize: DebugRequestMessage): DebugResponseMessage = { + val response = new DebugResponseMessage + response.setId(initialize.getId()) + response.setMethod(initialize.getMethod()) + response + } + } + object SyntheticMessage { def unapply(msg: IdentifiableMessage): Option[IdentifiableMessage] = { if (msg.getId == null) Some(msg) @@ -135,6 +144,28 @@ object DebugProtocol { } } + object ConfigurationDone { + def unapply( + request: DebugRequestMessage + ): Option[Option[dap.ConfigurationDoneArguments]] = { + if (request.getMethod != "configurationDone") None + else + Some( + parse[dap.ConfigurationDoneArguments](request.getParams).toOption + ) + } + } + + object TerminateRequest { + def unapply( + request: DebugRequestMessage + ): Option[dap.TerminateArguments] = { + if (request.getMethod != "terminate") None + else + parse[dap.TerminateArguments](request.getParams).toOption + } + } + object SetBreakpointRequest { def unapply(request: RequestMessage): Option[SetBreakpointsArguments] = { if (request.getMethod != "setBreakpoints") None @@ -158,15 +189,14 @@ object DebugProtocol { } } - object RestartRequest { - def unapply(request: RequestMessage): Option[DisconnectArguments] = { - if (request.getMethod != "disconnect") None + object DisconnectRequest { + def unapply(request: DebugRequestMessage): Option[DisconnectArguments] = { + if (request.getMethod != DisconnectRequest.name) None else { - parse[DisconnectArguments](request.getParams) - .filter(_.getRestart) - .toOption + parse[DisconnectArguments](request.getParams).toOption } } + val name = "disconnect" } object OutputNotification { diff --git a/metals/src/main/scala/scala/meta/internal/metals/debug/DebugProvider.scala b/metals/src/main/scala/scala/meta/internal/metals/debug/DebugProvider.scala index 2022b56bef6..4eeb2d8e057 100644 --- a/metals/src/main/scala/scala/meta/internal/metals/debug/DebugProvider.scala +++ b/metals/src/main/scala/scala/meta/internal/metals/debug/DebugProvider.scala @@ -42,6 +42,7 @@ import scala.meta.internal.metals.SourceMapper import scala.meta.internal.metals.StacktraceAnalyzer import scala.meta.internal.metals.StatusBar import scala.meta.internal.metals.UserConfiguration +import scala.meta.internal.metals.clients.language.LogForwarder import scala.meta.internal.metals.clients.language.MetalsLanguageClient import scala.meta.internal.metals.clients.language.MetalsQuickPickItem import scala.meta.internal.metals.clients.language.MetalsQuickPickParams @@ -87,13 +88,30 @@ class DebugProvider( sourceMapper: SourceMapper, userConfig: () => UserConfiguration, testProvider: TestSuitesProvider, -) extends Cancelable { +) extends Cancelable + with LogForwarder { import DebugProvider._ private val debugSessions = new MutableCancelable() - override def cancel(): Unit = debugSessions.cancel() + private val currentRunner = + new ju.concurrent.atomic.AtomicReference[DebugRunner](null) + + override def info(message: String): Unit = { + val runner = currentRunner.get() + if (runner != null) runner.stdout(message) + } + + override def error(message: String): Unit = { + val runner = currentRunner.get() + if (runner != null) runner.error(message) + } + + override def cancel(): Unit = { + Option(currentRunner.get()).foreach(_.cancel()) + debugSessions.cancel() + } lazy val buildTargetClassesFinder = new BuildTargetClassesFinder( buildTargets, @@ -102,9 +120,9 @@ class DebugProvider( ) def start( - parameters: b.DebugSessionParams, - cancelPromise: Promise[Unit], + parameters: b.DebugSessionParams )(implicit ec: ExecutionContext): Future[DebugServer] = { + val cancelPromise = Promise[Unit]() for { sessionName <- Future.fromTry(parseSessionName(parameters)) jvmOptionsTranslatedParams = translateJvmParams(parameters) @@ -112,23 +130,86 @@ class DebugProvider( .fold[Future[BuildServerConnection]](BuildServerUnavailableError)( Future.successful ) - debugServer <- start( + isJvm = parameters + .getTargets() + .asScala + .flatMap(buildTargets.scalaTarget) + .forall( + _.scalaInfo.getPlatform == b.ScalaPlatform.JVM + ) + debugServer <- + if (isJvm) + statusBar.trackSlowFuture( + "Starting debug server", + start( + sessionName, + jvmOptionsTranslatedParams, + buildServer, + cancelPromise, + ), + () => cancelPromise.trySuccess(()), + ) + else + runLocally( + sessionName, + jvmOptionsTranslatedParams, + buildServer, + cancelPromise, + ) + } yield debugServer + } + + private def runLocally( + sessionName: String, + parameters: b.DebugSessionParams, + buildServer: BuildServerConnection, + cancelPromise: Promise[Unit], + )(implicit ec: ExecutionContext): Future[DebugServer] = { + val proxyServer = new ServerSocket(0, 50, localAddress) + val host = InetAddresses.toUriString(proxyServer.getInetAddress) + val port = proxyServer.getLocalPort + proxyServer.setSoTimeout(10 * 1000) + val uri = URI.create(s"tcp://$host:$port") + + val awaitClient = () => Future(proxyServer.accept()) + + DebugRunner + .open( sessionName, - jvmOptionsTranslatedParams, - buildServer, + awaitClient, + stacktraceAnalyzer, + () => { + val runParams = new b.RunParams(parameters.getTargets().asScala.head) + buildServer.buildTargetRun(runParams, cancelPromise) + }, cancelPromise, ) - } yield debugServer + .flatMap { runner => + currentRunner.set(runner) + runner.listen.map { code => + currentRunner.set(null) + code + } + } + + val server = new DebugServer( + sessionName, + uri, + () => Future.failed(new RuntimeException("No server connected")), + ) + + Future.successful(server) } + private def localAddress: InetAddress = InetAddress.getByName("127.0.0.1") + private def start( sessionName: String, parameters: b.DebugSessionParams, buildServer: BuildServerConnection, cancelPromise: Promise[Unit], )(implicit ec: ExecutionContext): Future[DebugServer] = { - val inetAddress = InetAddress.getByName("127.0.0.1") - val proxyServer = new ServerSocket(0, 50, inetAddress) + val proxyServer = new ServerSocket(0, 50, localAddress) val host = InetAddresses.toUriString(proxyServer.getInetAddress) val port = proxyServer.getLocalPort proxyServer.setSoTimeout(10 * 1000) @@ -395,13 +476,8 @@ class DebugProvider( def asSession( debugParams: DebugSessionParams )(implicit ec: ExecutionContext): Future[DebugSession] = { - val cancelPromise = Promise[Unit]() for { - server <- statusBar.trackSlowFuture( - "Starting debug server", - start(debugParams, cancelPromise), - () => cancelPromise.trySuccess(()), - ) + server <- start(debugParams), } yield { statusBar.addMessage("Started debug server!") DebugSession(server.sessionName, server.uri.toString) diff --git a/metals/src/main/scala/scala/meta/internal/metals/debug/DebugProxy.scala b/metals/src/main/scala/scala/meta/internal/metals/debug/DebugProxy.scala index 47a8baf1909..9bbb77f2238 100644 --- a/metals/src/main/scala/scala/meta/internal/metals/debug/DebugProxy.scala +++ b/metals/src/main/scala/scala/meta/internal/metals/debug/DebugProxy.scala @@ -19,11 +19,11 @@ import scala.meta.internal.metals.StacktraceAnalyzer import scala.meta.internal.metals.StatusBar import scala.meta.internal.metals.Trace import scala.meta.internal.metals.debug.DebugProtocol.CompletionRequest +import scala.meta.internal.metals.debug.DebugProtocol.DisconnectRequest import scala.meta.internal.metals.debug.DebugProtocol.ErrorOutputNotification import scala.meta.internal.metals.debug.DebugProtocol.InitializeRequest import scala.meta.internal.metals.debug.DebugProtocol.LaunchRequest import scala.meta.internal.metals.debug.DebugProtocol.OutputNotification -import scala.meta.internal.metals.debug.DebugProtocol.RestartRequest import scala.meta.internal.metals.debug.DebugProtocol.SetBreakpointRequest import scala.meta.internal.metals.debug.DebugProxy._ import scala.meta.io.AbsolutePath @@ -91,7 +91,7 @@ private[debug] final class DebugProxy( case request @ LaunchRequest(debugMode) => this.debugMode = debugMode server.send(request) - case request @ RestartRequest(_) => + case request @ DisconnectRequest(args) if args.getRestart() => initialized.trySuccess(()) // set the status first, since the server can kill the connection exitStatus.trySuccess(Restarted) diff --git a/metals/src/main/scala/scala/meta/internal/metals/debug/DebugRunner.scala b/metals/src/main/scala/scala/meta/internal/metals/debug/DebugRunner.scala new file mode 100644 index 00000000000..8e4857decf0 --- /dev/null +++ b/metals/src/main/scala/scala/meta/internal/metals/debug/DebugRunner.scala @@ -0,0 +1,179 @@ +package scala.meta.internal.metals.debug + +import java.net.Socket +import java.util.concurrent.atomic.AtomicBoolean +import java.util.concurrent.atomic.AtomicInteger + +import scala.concurrent.ExecutionContext +import scala.concurrent.Future +import scala.concurrent.Promise +import scala.util.Try + +import scala.meta.internal.metals.Cancelable +import scala.meta.internal.metals.StacktraceAnalyzer + +import ch.epfl.scala.bsp4j.RunResult +import org.eclipse.lsp4j.debug.Capabilities +import org.eclipse.lsp4j.debug.ExitedEventArguments +import org.eclipse.lsp4j.debug.OutputEventArguments +import org.eclipse.lsp4j.debug.OutputEventArgumentsCategory +import org.eclipse.lsp4j.jsonrpc.MessageConsumer +import org.eclipse.lsp4j.jsonrpc.debug.messages.DebugNotificationMessage +import org.eclipse.lsp4j.jsonrpc.messages.IdentifiableMessage +import org.eclipse.lsp4j.jsonrpc.messages.Message + +/** + * Runner used for basic DAP run requests without the overhead of + * debugging. This onlt runs the program and responds very basic messages and + * events to the DAP client. + */ +class DebugRunner( + sessionName: String, + client: RemoteEndpoint, + stackTraceAnalyzer: StacktraceAnalyzer, + runFuture: () => Future[RunResult], + cancelling: Promise[Unit], +)(implicit ec: ExecutionContext) { + + val lastId = new AtomicInteger(DebugProtocol.FirstMessageId) + private val clientReady: Promise[Unit] = Promise[Unit]() + private val exitStatus: Future[DebugProxy.ExitStatus] = + for { + _ <- clientReady.future + res <- runFuture() + } yield { + val exited = new DebugNotificationMessage + exited.setMethod("exited") + val exitedArgs = new ExitedEventArguments() + exitedArgs.setExitCode(res.getStatusCode().getValue()) + exited.setParams(exitedArgs) + exited.setId(lastId.incrementAndGet()) + client.consume(exited) + + val terminated = new DebugNotificationMessage + terminated.setMethod("terminated") + terminated.setId(lastId.incrementAndGet()) + client.consume(terminated) + DebugProxy.Terminated + } + + private val cancelled = new AtomicBoolean() + private def listenToClient(): Unit = { + Future(client.listen(handleClientMessage)).andThen { case _ => cancel() } + } + + lazy val listen: Future[DebugProxy.ExitStatus] = { + scribe.info(s"Starting debug runner for [$sessionName]") + listenToClient() + exitStatus.map { st => + client.cancel() + st + } + } + + private val handleClientMessage: MessageConsumer = { message => + setIdFromMessage(message) + message match { + case null => + () // ignore + case _ if cancelled.get() => + () // ignore + case request @ DebugProtocol.InitializeRequest(_) => + val response = DebugProtocol.EmptyResponse(request) + response.setResult(new Capabilities) + client.consume(response) + + case request @ DebugProtocol.LaunchRequest(_) => + val response = DebugProtocol.EmptyResponse(request) + clientReady.trySuccess(()) + client.consume(response) + + case request @ DebugProtocol.ConfigurationDone(_) => + val response = DebugProtocol.EmptyResponse(request) + client.consume(response) + + case request @ DebugProtocol.DisconnectRequest(_) => + val response = DebugProtocol.EmptyResponse(request) + clientReady.trySuccess(()) + client.consume(response) + + case request @ DebugProtocol.TerminateRequest(_) => + val response = DebugProtocol.EmptyResponse(request) + clientReady.trySuccess(()) + cancelling.trySuccess(()) + client.consume(response) + + case message => + scribe.debug("Message not handled:\n" + message) + } + } + + def setIdFromMessage(msg: Message): Unit = { + msg match { + case message: IdentifiableMessage => + Try(message.getId().toInt).foreach(lastId.set) + case _ => + } + } + + def stdout(message: String): Unit = + output(message, OutputEventArgumentsCategory.STDOUT)(_ => None) + + def error(message: String): Unit = { + output(message, OutputEventArgumentsCategory.STDERR) { notification => + stackTraceAnalyzer + .fileLocationFromLine(message) + .map(DebugProtocol.stacktraceOutputResponse(notification, _)) + + } + } + + private def output(message: String, category: String)( + withDetails: OutputEventArguments => Option[DebugNotificationMessage] + ) = { + val output = new OutputEventArguments() + output.setCategory(category) + output.setOutput(message + "\n") + + def default = { + val notification = new DebugNotificationMessage() + notification.setMethod("output") + notification.setParams(output) + notification + } + + val notification = withDetails(output).getOrElse(default) + notification.setId(lastId.incrementAndGet()) + client.consume(notification) + } + + def cancel(): Unit = { + if (cancelled.compareAndSet(false, true)) { + scribe.info(s"Canceling run for [$sessionName]") + Cancelable.cancelAll(List(client)) + } + } +} + +object DebugRunner { + + def open( + name: String, + awaitClient: () => Future[Socket], + stackTraceAnalyzer: StacktraceAnalyzer, + runFuture: () => Future[RunResult], + cancelling: Promise[Unit], + )(implicit ec: ExecutionContext): Future[DebugRunner] = { + for { + client <- awaitClient() + .map(new SocketEndpoint(_)) + .map(new MessageIdAdapter(_)) + } yield new DebugRunner( + name, + client, + stackTraceAnalyzer, + runFuture, + cancelling, + ) + } +} diff --git a/tests/slow/src/test/scala/tests/scalacli/ScalaCliSuite.scala b/tests/slow/src/test/scala/tests/scalacli/ScalaCliSuite.scala index 640fbbf1b5d..4ea8cfd69f9 100644 --- a/tests/slow/src/test/scala/tests/scalacli/ScalaCliSuite.scala +++ b/tests/slow/src/test/scala/tests/scalacli/ScalaCliSuite.scala @@ -14,6 +14,9 @@ import scala.meta.internal.metals.{BuildInfo => V} import org.eclipse.{lsp4j => l} import tests.FileLayout +import scala.meta.internal.metals.debug.TestDebugger +import scala.meta.internal.metals.DebugUnresolvedMainClassParams +import scala.meta.internal.metals.JsonParser._ class ScalaCliSuite extends BaseScalaCliSuite(V.scala3) { override def serverConfig: MetalsServerConfig = @@ -24,6 +27,8 @@ class ScalaCliSuite extends BaseScalaCliSuite(V.scala3) { InitializationOptions.Default.copy( inlineDecorationProvider = Some(true), decorationProvider = Some(true), + debuggingProvider = Option(true), + runProvider = Option(true), ) ) @@ -546,4 +551,118 @@ class ScalaCliSuite extends BaseScalaCliSuite(V.scala3) { } yield () } + def startDebugging( + main: String, + buildTarget: String, + ): Future[TestDebugger] = { + server.startDebuggingUnresolved( + new DebugUnresolvedMainClassParams(main, buildTarget).toJson + ) + } + + test("base-native-run") { + cleanWorkspace() + for { + _ <- scalaCliInitialize(useBsp = true)( + s"""/MyMain.scala + |//> using scala "$scalaVersion" + |//> using platform "native" + | + |import scala.scalanative._ + | + |object MyMain { + | def main(args: Array[String]): Unit = { + | println("Hello world!") + | System.exit(0) + | } + |} + | + |""".stripMargin + ) + _ <- server.didOpen("MyMain.scala") + textWithLenses <- server.codeLensesText( + "MyMain.scala", + printCommand = false, + )(maxRetries = 5) + _ = assertNoDiff( + textWithLenses, + s"""|//> using scala "$scalaVersion" + |//> using platform "native" + | + |import scala.scalanative._ + | + |<> + |object MyMain { + | def main(args: Array[String]): Unit = { + | println("Hello world!") + | System.exit(0) + | } + |} + | + |""".stripMargin, + ) + targets <- server.listBuildTargets + mainTarget = targets.find(!_.contains("test")) + _ = assert(mainTarget.isDefined, "No main target specified") + debugServer <- startDebugging("MyMain", mainTarget.get) + _ <- debugServer.initialize + _ <- debugServer.launch + _ <- debugServer.configurationDone + _ <- debugServer.shutdown + output <- debugServer.allOutput + } yield assertContains(output, "Hello world!\n") + } + + test("base-js-run") { + cleanWorkspace() + for { + _ <- scalaCliInitialize(useBsp = true)( + s"""/MyMain.scala + |//> using scala "$scalaVersion" + |//> using platform "js" + | + |import scala.scalajs.js + | + |object MyMain { + | def main(args: Array[String]): Unit = { + | println("Hello world!") + | // System.exit(0) + | } + |} + | + |""".stripMargin + ) + _ <- server.didOpen("MyMain.scala") + textWithLenses <- server.codeLensesText( + "MyMain.scala", + printCommand = false, + )(maxRetries = 5) + _ = assertNoDiff( + textWithLenses, + s"""|//> using scala "$scalaVersion" + |//> using platform "js" + | + |import scala.scalajs.js + | + |<> + |object MyMain { + | def main(args: Array[String]): Unit = { + | println("Hello world!") + | // System.exit(0) + | } + |} + | + |""".stripMargin, + ) + targets <- server.listBuildTargets + mainTarget = targets.find(!_.contains("test")) + _ = assert(mainTarget.isDefined, "No main target specified") + debugServer <- startDebugging("MyMain", mainTarget.get) + _ <- debugServer.initialize + _ <- debugServer.launch + _ <- debugServer.configurationDone + _ <- debugServer.shutdown + output <- debugServer.allOutput + } yield assertContains(output, "Hello world!\n") + } } diff --git a/tests/unit/src/main/scala/scala/meta/internal/metals/debug/RemoteServer.scala b/tests/unit/src/main/scala/scala/meta/internal/metals/debug/RemoteServer.scala index 5927181443f..de301d0a9b2 100644 --- a/tests/unit/src/main/scala/scala/meta/internal/metals/debug/RemoteServer.scala +++ b/tests/unit/src/main/scala/scala/meta/internal/metals/debug/RemoteServer.scala @@ -16,7 +16,7 @@ import scala.reflect.classTag import scala.meta.internal.metals.Cancelable import scala.meta.internal.metals.JsonParser._ import scala.meta.internal.metals.MetalsEnrichments._ -import scala.meta.internal.metals.debug.DebugProtocol.FirstMessageId +import scala.meta.internal.metals.debug.DebugProtocol import com.google.gson.JsonElement import org.eclipse.lsp4j.debug.Capabilities @@ -40,7 +40,7 @@ private[debug] final class RemoteServer( private val remote = new SocketEndpoint(socket) private val ongoing = new TrieMap[String, Response => Unit]() - private val id = new AtomicInteger(FirstMessageId) + private val id = new AtomicInteger(DebugProtocol.FirstMessageId) lazy val listening: Future[Unit] = Future(listen()) override def initialize( @@ -124,7 +124,7 @@ private[debug] final class RemoteServer( override def disconnect( args: DisconnectArguments ): CompletableFuture[Void] = { - sendRequest("disconnect", args) + sendRequest(DebugProtocol.DisconnectRequest.name, args) } private def listen(): Unit = { diff --git a/tests/unit/src/main/scala/tests/BaseCodeLensLspSuite.scala b/tests/unit/src/main/scala/tests/BaseCodeLensLspSuite.scala index 845e28fc418..9b15a5574eb 100644 --- a/tests/unit/src/main/scala/tests/BaseCodeLensLspSuite.scala +++ b/tests/unit/src/main/scala/tests/BaseCodeLensLspSuite.scala @@ -110,6 +110,7 @@ abstract class BaseCodeLensLspSuite( printCommand: Boolean = false, extraInitialization: (TestingServer, String) => Future[Unit] = (_, _) => Future.unit, + minExpectedLenses: Int = 1, )( expected: => String )(implicit loc: Location): Unit = { @@ -144,7 +145,12 @@ abstract class BaseCodeLensLspSuite( |""".stripMargin ) _ <- extraInitialization(server, sourceFile) - _ <- assertCodeLenses(sourceFile, expected, printCommand = printCommand) + _ <- assertCodeLenses( + sourceFile, + expected, + printCommand = printCommand, + minExpectedLenses = minExpectedLenses, + ) } yield () } } @@ -154,6 +160,7 @@ abstract class BaseCodeLensLspSuite( library: Option[String] = None, scalaVersion: Option[String] = None, printCommand: Boolean = false, + minExpectedLenses: Int = 1, )( expected: => String )(implicit loc: Location): Unit = check( @@ -163,6 +170,7 @@ abstract class BaseCodeLensLspSuite( printCommand, (server, sourceFile) => server.discoverTestSuites(List(sourceFile)).map(_ => ()), + minExpectedLenses, )(expected) protected def assertCodeLenses( @@ -170,12 +178,16 @@ abstract class BaseCodeLensLspSuite( expected: String, maxRetries: Int = 4, printCommand: Boolean = false, + minExpectedLenses: Int = 1, )(implicit loc: Location): Future[Unit] = { val obtained = - server.codeLensesText(relativeFile, printCommand)(maxRetries).recover { - case _: NoSuchElementException => + server + .codeLensesText(relativeFile, printCommand, minExpectedLenses)( + maxRetries + ) + .recover { case _: NoSuchElementException => server.textContents(relativeFile) - } + } obtained.map(assertNoDiff(_, expected)) } diff --git a/tests/unit/src/main/scala/tests/TestingClient.scala b/tests/unit/src/main/scala/tests/TestingClient.scala index 00beb8840ff..3f198043cc0 100644 --- a/tests/unit/src/main/scala/tests/TestingClient.scala +++ b/tests/unit/src/main/scala/tests/TestingClient.scala @@ -55,6 +55,7 @@ import org.eclipse.lsp4j.WorkspaceEdit import org.eclipse.lsp4j.jsonrpc.CompletableFutures import tests.MetalsTestEnrichments._ import tests.TestOrderings._ +import scala.meta.internal.metals.Debug /** * Fake LSP client that responds to notifications/requests initiated by the server. @@ -131,6 +132,7 @@ class TestingClient(workspace: AbsolutePath, val buffers: Buffers) override def metalsExecuteClientCommand( params: ExecuteCommandParams ): Unit = { + Debug.printEnclosing(params.getCommand()) clientCommands.addLast(params) params.getCommand match { case ClientCommands.RefreshModel.id => diff --git a/tests/unit/src/main/scala/tests/TestingServer.scala b/tests/unit/src/main/scala/tests/TestingServer.scala index ed664d0c72f..2a3513728c5 100644 --- a/tests/unit/src/main/scala/tests/TestingServer.scala +++ b/tests/unit/src/main/scala/tests/TestingServer.scala @@ -20,6 +20,8 @@ import scala.collection.mutable.ListBuffer import scala.concurrent.ExecutionContextExecutorService import scala.concurrent.Future import scala.concurrent.Promise +import scala.util.Failure +import scala.util.Success import scala.util.matching.Regex import scala.{meta => m} @@ -635,6 +637,18 @@ final case class TestingServer( fullServer.executeCommand(command.toExecuteCommandParams()).asScala } + def listBuildTargets: Future[List[String]] = { + for { + targetsArray <- executeCommand(ServerCommands.ListBuildTargets) + } yield targetsArray.toJson.as[Array[String]] match { + case Failure(exception) => + scribe.error("Could not read build targets", exception) + Nil + case Success(targets) => + targets.toList + } + } + /** * Operating on strings can be dangerous, but needed for running unknown commands * and for the StartDebugAdapter command, which doesn't have a stable argument. @@ -1025,11 +1039,15 @@ final case class TestingServer( } yield classes } - def codeLensesText(filename: String, printCommand: Boolean = false)( + def codeLensesText( + filename: String, + printCommand: Boolean = false, + minExpectedLenses: Int = 1, + )( maxRetries: Int ): Future[String] = { for { - lenses <- codeLenses(filename, maxRetries) + lenses <- codeLenses(filename, maxRetries, minExpectedLenses) textEdits = CodeLensesTextEdits(lenses, printCommand) } yield TextEdits.applyEdits(textContents(filename), textEdits.toList) } @@ -1037,6 +1055,7 @@ final case class TestingServer( def codeLenses( filename: String, maxRetries: Int = 4, + minExpectedLenses: Int = 1, ): Future[List[l.CodeLens]] = { Debug.printEnclosing(filename) val path = toPath(filename) @@ -1052,21 +1071,24 @@ final case class TestingServer( // or fails if it could nat be achieved withing [[maxRetries]] number of tries var retries = maxRetries val codeLenses = Promise[List[l.CodeLens]]() + def getLenses = fullServer + .codeLens(params) + .asScala + .map(_.asScala) + .withTimeout(10, util.concurrent.TimeUnit.SECONDS) + .recover { _ => + scribe.info(s"Timeout for fetching lenses reached for $filename") + Nil + } + val handler = { refreshCount: Int => scribe.info(s"Refreshing model for $filename") if (refreshCount > 0) for { - lenses <- fullServer - .codeLens(params) - .asScala - .map(_.asScala) - .withTimeout(10, util.concurrent.TimeUnit.SECONDS) - .recover { _ => - scribe.info(s"Timeout for fetching lenses reached for $filename") - Nil - } + lenses <- getLenses } { - if (lenses.nonEmpty) codeLenses.trySuccess(lenses.toList) + if (lenses.size >= minExpectedLenses) + codeLenses.trySuccess(lenses.toList) else if (retries > 0) { retries -= 1 server.compilations.compileFile(path) @@ -1085,9 +1107,13 @@ final case class TestingServer( _ = client.refreshModelHandler = handler // first compilation, to trigger the handler _ <- server.compilations.compileFile(path) - lenses <- codeLenses.future + lenses <- getLenses + .flatMap { lenses => + if (lenses.size >= minExpectedLenses) Future.successful(lenses) + else codeLenses.future + } .withTimeout(60, util.concurrent.TimeUnit.SECONDS) - } yield lenses + } yield lenses.toList } def formatCompletion( diff --git a/tests/unit/src/test/scala/tests/CodeLensLspSuite.scala b/tests/unit/src/test/scala/tests/CodeLensLspSuite.scala index 9b03c3b8c6e..c18620a7ec2 100644 --- a/tests/unit/src/test/scala/tests/CodeLensLspSuite.scala +++ b/tests/unit/src/test/scala/tests/CodeLensLspSuite.scala @@ -53,6 +53,7 @@ class CodeLensLspSuite extends BaseCodeLensLspSuite("codeLenses") { checkTestCases( "test-suite-with-tests", library = Some("org.scalatest::scalatest:3.2.16"), + minExpectedLenses = 6, )( """|package foo.bar |<><>