Skip to content

Commit

Permalink
feature: Add support for running Native
Browse files Browse the repository at this point in the history
Works by doing the bare minimum DAP requires and forwars the logs from the build server
  • Loading branch information
tgodzik committed Dec 27, 2023
1 parent d596f64 commit 85cd09c
Show file tree
Hide file tree
Showing 14 changed files with 588 additions and 89 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,17 @@ class BuildServerConnection private (
}
}

def buildTargetRun(
params: RunParams,
cancelPromise: Promise[Unit],
): Future[RunResult] = {
val completableFuture = register(server => server.buildTargetRun(params))
cancelPromise.future.foreach { _ =>
completableFuture.cancel(true)
}
completableFuture.asScala
}

def buildTargetScalacOptions(
params: ScalacOptionsParams
): Future[ScalacOptionsResult] = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package scala.meta.internal.metals.clients.language

import java.util.Collections
import java.util.concurrent.ConcurrentHashMap
import java.{util => ju}
import java.util.concurrent.atomic.AtomicReference

import scala.collection.concurrent.TrieMap
import scala.concurrent.Promise
Expand All @@ -23,12 +23,24 @@ import scala.meta.internal.metals.Timer
import scala.meta.internal.tvp._
import scala.meta.io.AbsolutePath

import ch.epfl.scala.bsp4j._
import ch.epfl.scala.{bsp4j => b}
import com.google.gson.JsonObject
import org.eclipse.lsp4j.jsonrpc.services.JsonNotification
import org.eclipse.{lsp4j => l}

/**
* Used to forward messages from the build server. Messages might
* be mixed if the server is sending messages as well as output from
* running. This hasn't been a problem yet, not perfect solution,
* but seems to work ok.
*/
trait LogForwarder {
def error(message: String): Unit = ()
def warn(message: String): Unit = ()
def info(message: String): Unit = ()
def log(message: String): Unit = ()
}

/**
* A build client that forwards notifications from the build server to the language client.
*/
Expand All @@ -39,31 +51,39 @@ final class ForwardingMetalsBuildClient(
clientConfig: ClientConfiguration,
statusBar: StatusBar,
time: Time,
didCompile: CompileReport => Unit,
onBuildTargetDidCompile: BuildTargetIdentifier => Unit,
didCompile: b.CompileReport => Unit,
onBuildTargetDidCompile: b.BuildTargetIdentifier => Unit,
onBuildTargetDidChangeFunc: b.DidChangeBuildTarget => Unit,
bspErrorHandler: BspErrorHandler,
) extends MetalsBuildClient
with Cancelable {

private val forwarders =
new AtomicReference(List.empty[LogForwarder])

def registerLogForwarder(
logForwarder: LogForwarder
): List[LogForwarder] = {
forwarders.getAndUpdate(_.prepended(logForwarder))
}
private case class Compilation(
timer: Timer,
promise: Promise[CompileReport],
promise: Promise[b.CompileReport],
isNoOp: Boolean,
progress: TaskProgress = TaskProgress.empty,
) extends TreeViewCompilation {
def progressPercentage = progress.percentage
}

private val compilations = TrieMap.empty[BuildTargetIdentifier, Compilation]
private val compilations = TrieMap.empty[b.BuildTargetIdentifier, Compilation]
private val hasReportedError = Collections.newSetFromMap(
new ConcurrentHashMap[BuildTargetIdentifier, java.lang.Boolean]()
new ConcurrentHashMap[b.BuildTargetIdentifier, java.lang.Boolean]()
)

val updatedTreeViews: ju.Set[BuildTargetIdentifier] =
ConcurrentHashSet.empty[BuildTargetIdentifier]
val updatedTreeViews: java.util.Set[b.BuildTargetIdentifier] =
ConcurrentHashSet.empty[b.BuildTargetIdentifier]

def buildHasErrors(buildTargetId: BuildTargetIdentifier): Boolean = {
def buildHasErrors(buildTargetId: b.BuildTargetIdentifier): Boolean = {
buildTargets
.buildTargetTransitiveDependencies(buildTargetId)
.exists(hasReportedError.contains(_))
Expand Down Expand Up @@ -96,18 +116,22 @@ final class ForwardingMetalsBuildClient(
def onBuildShowMessage(params: l.MessageParams): Unit =
languageClient.showMessage(params)

def onBuildLogMessage(params: l.MessageParams): Unit =
def onBuildLogMessage(params: l.MessageParams): Unit = {
params.getType match {
case l.MessageType.Error =>
bspErrorHandler.onError(params.getMessage())
forwarders.get().foreach(_.error(params.getMessage()))
case l.MessageType.Warning =>
forwarders.get().foreach(_.warn(params.getMessage()))
scribe.warn(params.getMessage)
case l.MessageType.Info =>
forwarders.get().foreach(_.info(params.getMessage()))
scribe.info(params.getMessage)
case l.MessageType.Log =>
forwarders.get().foreach(_.log(params.getMessage()))
scribe.info(params.getMessage)
}

}
def onBuildPublishDiagnostics(params: b.PublishDiagnosticsParams): Unit = {
diagnostics.onBuildPublishDiagnostics(params)
}
Expand All @@ -119,9 +143,9 @@ final class ForwardingMetalsBuildClient(
def onBuildTargetCompileReport(params: b.CompileReport): Unit = {}

@JsonNotification("build/taskStart")
def buildTaskStart(params: TaskStartParams): Unit = {
def buildTaskStart(params: b.TaskStartParams): Unit = {
params.getDataKind match {
case TaskStartDataKind.COMPILE_TASK =>
case b.TaskStartDataKind.COMPILE_TASK =>
if (
params.getMessage != null && params.getMessage.startsWith("Compiling")
) {
Expand All @@ -137,7 +161,7 @@ final class ForwardingMetalsBuildClient(
compilations.remove(target).foreach(_.promise.cancel())

val name = info.getDisplayName
val promise = Promise[CompileReport]()
val promise = Promise[b.CompileReport]()
val isNoOp =
params.getMessage != null && params.getMessage.startsWith(
"Start no-op compilation"
Expand All @@ -157,9 +181,9 @@ final class ForwardingMetalsBuildClient(
}

@JsonNotification("build/taskFinish")
def buildTaskFinish(params: TaskFinishParams): Unit = {
def buildTaskFinish(params: b.TaskFinishParams): Unit = {
params.getDataKind match {
case TaskFinishDataKind.COMPILE_REPORT =>
case b.TaskFinishDataKind.COMPILE_REPORT =>
for {
report <- params.asCompileReport
compilation <- compilations.remove(report.getTarget)
Expand Down Expand Up @@ -214,8 +238,8 @@ final class ForwardingMetalsBuildClient(
}

@JsonNotification("build/taskProgress")
def buildTaskProgress(params: TaskProgressParams): Unit = {
def buildTargetFromParams: Option[BuildTargetIdentifier] =
def buildTaskProgress(params: b.TaskProgressParams): Unit = {
def buildTargetFromParams: Option[b.BuildTargetIdentifier] =
for {
data <- Option(params.getData).collect { case o: JsonObject =>
o
Expand All @@ -227,7 +251,7 @@ final class ForwardingMetalsBuildClient(
if uriElement.isJsonPrimitive
uri = uriElement.getAsJsonPrimitive
if uri.isString
} yield new BuildTargetIdentifier(uri.getAsString)
} yield new b.BuildTargetIdentifier(uri.getAsString)

params.getDataKind match {
case "bloop-progress" =>
Expand All @@ -252,7 +276,7 @@ final class ForwardingMetalsBuildClient(

def ongoingCompilations(): TreeViewCompilations =
new TreeViewCompilations {
override def get(id: BuildTargetIdentifier) = compilations.get(id)
override def get(id: b.BuildTargetIdentifier) = compilations.get(id)
override def isEmpty = compilations.isEmpty
override def size = compilations.size
override def buildTargets = compilations.keysIterator
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -728,6 +728,7 @@ class MetalsLspService(
testProvider,
)
)
buildClient.registerLogForwarder(debugProvider)

private val scalafixProvider: ScalafixProvider = ScalafixProvider(
buffers,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,7 @@ final class RunTestCodeLens(
val lenses = for {
buildTargetId <- buildTargets.inverseSources(path)
buildTarget <- buildTargets.info(buildTargetId)
// generate code lenses only for JVM based targets for Scala
if buildTarget.asScalaBuildTarget.forall(
isJVM = buildTarget.asScalaBuildTarget.forall(
_.getPlatform == b.ScalaPlatform.JVM
)
connection <- buildTargets.buildServerOf(buildTargetId)
Expand All @@ -85,6 +84,7 @@ final class RunTestCodeLens(
classes,
distance,
buildServerCanDebug,
isJVM,
)
} else if (buildServerCanDebug || clientConfig.isRunProvider()) {
codeLenses(
Expand All @@ -94,6 +94,7 @@ final class RunTestCodeLens(
distance,
path,
buildServerCanDebug,
isJVM,
)
} else { Nil }

Expand Down Expand Up @@ -160,6 +161,7 @@ final class RunTestCodeLens(
Nil.asJava,
),
buildServerCanDebug,
isJVM = true,
)
else
Nil
Expand All @@ -177,6 +179,7 @@ final class RunTestCodeLens(
distance: TokenEditDistance,
path: AbsolutePath,
buildServerCanDebug: Boolean,
isJVM: Boolean,
): Seq[l.CodeLens] = {
for {
occurrence <- textDocument.occurrences
Expand All @@ -185,19 +188,19 @@ final class RunTestCodeLens(
commands = {
val main = classes.mainClasses
.get(symbol)
.map(mainCommand(target, _, buildServerCanDebug))
.map(mainCommand(target, _, buildServerCanDebug, isJVM))
.getOrElse(Nil)
val tests =
// Currently tests can only be run via DAP
if (clientConfig.isDebuggingProvider() && buildServerCanDebug)
testClasses(target, classes, symbol)
testClasses(target, classes, symbol, isJVM)
else Nil
val fromAnnot = DebugProvider
.mainFromAnnotation(occurrence, textDocument)
.flatMap { symbol =>
classes.mainClasses
.get(symbol)
.map(mainCommand(target, _, buildServerCanDebug))
.map(mainCommand(target, _, buildServerCanDebug, isJVM))
}
.getOrElse(Nil)
val javaMains =
Expand Down Expand Up @@ -225,6 +228,7 @@ final class RunTestCodeLens(
classes: BuildTargetClasses.Classes,
distance: TokenEditDistance,
buildServerCanDebug: Boolean,
isJVM: Boolean,
): Seq[l.CodeLens] = {
val scriptFileName = textDocument.uri.stripSuffix(".sc")

Expand All @@ -234,15 +238,15 @@ final class RunTestCodeLens(
val main =
classes.mainClasses
.get(expectedMainClass)
.map(mainCommand(target, _, buildServerCanDebug))
.map(mainCommand(target, _, buildServerCanDebug, isJVM))
.getOrElse(Nil)

val fromAnnotations = textDocument.occurrences.flatMap { occ =>
for {
sym <- DebugProvider.mainFromAnnotation(occ, textDocument)
cls <- classes.mainClasses.get(sym)
range <- occurrenceRange(occ, distance)
} yield mainCommand(target, cls, buildServerCanDebug).map { cmd =>
} yield mainCommand(target, cls, buildServerCanDebug, isJVM).map { cmd =>
new l.CodeLens(range, cmd, null)
}
}.flatten
Expand All @@ -262,64 +266,78 @@ final class RunTestCodeLens(
target: BuildTargetIdentifier,
classes: BuildTargetClasses.Classes,
symbol: String,
isJVM: Boolean,
): List[l.Command] =
if (userConfig().testUserInterface == TestUserInterfaceKind.CodeLenses)
classes.testClasses
.get(symbol)
.toList
.flatMap(symbolInfo =>
testCommand(target, symbolInfo.fullyQualifiedName)
testCommand(target, symbolInfo.fullyQualifiedName, isJVM)
)
else
Nil

private def testCommand(
target: b.BuildTargetIdentifier,
className: String,
isJVM: Boolean,
): List[l.Command] = {
val params = {
val dataKind = b.TestParamsDataKind.SCALA_TEST_SUITES
val data = singletonList(className).toJson
sessionParams(target, dataKind, data)
}

List(
command("test", StartRunSession, params),
command("debug test", StartDebugSession, params),
)
if (isJVM)
List(
command("test", StartRunSession, params),
command("debug test", StartDebugSession, params),
)
else
List(
command("test", StartRunSession, params)
)
}

private def mainCommand(
target: b.BuildTargetIdentifier,
main: b.ScalaMainClass,
buildServerCanDebug: Boolean,
isJVM: Boolean,
): List[l.Command] = {
val javaBinary = buildTargets
.scalaTarget(target)
.flatMap(scalaTarget =>
JavaBinary.javaBinaryFromPath(scalaTarget.jvmHome)
)
.orElse(userConfig().usedJavaBinary)
val (data, shellCommandAdded) = buildTargetClasses.jvmRunEnvironment
.get(target)
.zip(javaBinary) match {
case None =>
(main.toJson, false)
case Some((env, javaBinary)) =>
(ExtendedScalaMainClass(main, env, javaBinary, workspace).toJson, true)
}
val (data, shellCommandAdded) =
if (!isJVM) (main.toJson, false)
else
buildTargetClasses.jvmRunEnvironment
.get(target)
.zip(javaBinary) match {
case None =>
(main.toJson, false)
case Some((env, javaBinary)) =>
(
ExtendedScalaMainClass(main, env, javaBinary, workspace).toJson,
true,
)
}
val params = {
val dataKind = b.DebugSessionParamsDataKind.SCALA_MAIN_CLASS
sessionParams(target, dataKind, data)
}

if (clientConfig.isDebuggingProvider() && buildServerCanDebug)
if (clientConfig.isDebuggingProvider() && buildServerCanDebug && isJVM)
List(
command("run", StartRunSession, params),
command("debug", StartDebugSession, params),
)
// run provider needs shell command to run currently, we don't support pure run inside metals
else if (shellCommandAdded && clientConfig.isRunProvider())
// run provider needs shell command to run currently, we don't support pure run inside metals for JVM
else if ((shellCommandAdded || !isJVM) && clientConfig.isRunProvider())
List(command("run", StartRunSession, params))
else Nil
}
Expand Down

0 comments on commit 85cd09c

Please sign in to comment.