Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[2.x] Remote caching support #7525

Merged
merged 6 commits into from
Apr 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
31 changes: 30 additions & 1 deletion build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -419,11 +419,26 @@ lazy val utilRelation = (project in file("internal") / "util-relation")
// Persisted caching based on sjson-new
lazy val utilCache = project
.in(file("util-cache"))
.enablePlugins(
ContrabandPlugin,
// we generate JsonCodec only for actionresult.conta
// JsonCodecPlugin,
)
.dependsOn(utilLogging)
.settings(
testedBaseSettings,
name := "Util Cache",
libraryDependencies ++=
Seq(sjsonNewScalaJson.value, sjsonNewMurmurhash.value, scalaReflect.value),
Seq(
sjsonNewCore.value,
sjsonNewScalaJson.value,
sjsonNewMurmurhash.value,
scalaReflect.value
),
Compile / managedSourceDirectories +=
baseDirectory.value / "src" / "main" / "contraband-scala",
Compile / generateContrabands / sourceManaged := baseDirectory.value / "src" / "main" / "contraband-scala",
Compile / generateContrabands / contrabandFormatsForType := ContrabandConfig.getFormats,
utilMimaSettings,
Test / fork := true,
)
Expand Down Expand Up @@ -645,6 +660,19 @@ lazy val dependencyTreeProj = (project in file("dependency-tree"))
mimaPreviousArtifacts := Set.empty,
)

lazy val remoteCacheProj = (project in file("sbt-remote-cache"))
.dependsOn(sbtProj)
.settings(
sbtPlugin := true,
baseSettings,
name := "sbt-remote-cache",
pluginCrossBuild / sbtVersion := version.value,
publishMavenStyle := true,
// mimaSettings,
mimaPreviousArtifacts := Set.empty,
libraryDependencies += remoteapis,
)

// Implementation and support code for defining actions.
lazy val actionsProj = (project in file("main-actions"))
.dependsOn(
Expand Down Expand Up @@ -1266,6 +1294,7 @@ def allProjects =
utilTracking,
collectionProj,
coreMacrosProj,
remoteCacheProj,
) ++ lowerUtilProjects

// These need to be cross published to 2.12 and 2.13 for Zinc
Expand Down
23 changes: 14 additions & 9 deletions main-actions/src/main/scala/sbt/Tests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import testing.{
}

import scala.annotation.tailrec
import scala.util.control.NonFatal
import sbt.internal.util.ManagedLogger
import sbt.util.Logger
import sbt.protocol.testing.TestResult
Expand Down Expand Up @@ -534,15 +535,19 @@ object Tests {
case analysis: Analysis =>
val acs: Seq[xsbti.api.AnalyzedClass] = analysis.apis.internal.values.toVector
acs.flatMap { ac =>
val companions = ac.api
val all =
Seq(companions.classApi: Definition, companions.objectApi: Definition) ++
(companions.classApi.structure.declared.toSeq: Seq[Definition]) ++
(companions.classApi.structure.inherited.toSeq: Seq[Definition]) ++
(companions.objectApi.structure.declared.toSeq: Seq[Definition]) ++
(companions.objectApi.structure.inherited.toSeq: Seq[Definition])

all
try
val companions = ac.api
val all =
Seq(companions.classApi: Definition, companions.objectApi: Definition) ++
(companions.classApi.structure.declared.toSeq: Seq[Definition]) ++
(companions.classApi.structure.inherited.toSeq: Seq[Definition]) ++
(companions.objectApi.structure.declared.toSeq: Seq[Definition]) ++
(companions.objectApi.structure.inherited.toSeq: Seq[Definition])
all
catch
case NonFatal(e) =>
if e.getMessage.startsWith("No companions") then Nil
else throw e
}.toSeq
}
def discover(
Expand Down
2 changes: 2 additions & 0 deletions main-settings/src/main/scala/sbt/Def.scala
Original file line number Diff line number Diff line change
Expand Up @@ -233,10 +233,12 @@ object Def extends Init[Scope] with TaskMacroExtra with InitializeImplicits:
private[sbt] var _cacheStore: ActionCacheStore = InMemoryActionCacheStore()
def cacheStore: ActionCacheStore = _cacheStore
private[sbt] var _outputDirectory: Option[Path] = None
private[sbt] val cacheEventLog: CacheEventLog = CacheEventLog()
def cacheConfiguration: BuildWideCacheConfiguration =
BuildWideCacheConfiguration(
_cacheStore,
_outputDirectory.getOrElse(sys.error("outputDirectory has not been set")),
cacheEventLog,
)

inline def cachedTask[A1: JsonFormat](inline a1: A1): Def.Initialize[Task[A1]] =
Expand Down
10 changes: 10 additions & 0 deletions main/src/main/scala/sbt/Defaults.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4654,6 +4654,16 @@ trait BuildExtra extends BuildCommon with DefExtra {
scalaBinaryVersion.value
)

/**
* Adds remote cache plugin.
*/
def addRemoteCachePlugin: Setting[Seq[ModuleID]] =
libraryDependencies += sbtPluginExtra(
ModuleID("org.scala-sbt", "sbt-remote-cache", sbtVersion.value),
sbtBinaryVersion.value,
scalaBinaryVersion.value
)

/**
* Adds `dependency` as an sbt plugin for the specific sbt version `sbtVersion` and Scala version `scalaVersion`.
* Typically, use the default values for these versions instead of specifying them explicitly.
Expand Down
12 changes: 11 additions & 1 deletion main/src/main/scala/sbt/Keys.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ package sbt

import java.nio.file.{ Path => NioPath }
import java.io.File
import java.net.URL
import java.net.{ URL, URI }
import lmcoursier.definitions.{ CacheLogger, ModuleMatchers, Reconciliation }
import lmcoursier.{ CoursierConfiguration, FallbackDependency }
import org.apache.ivy.core.module.descriptor.ModuleDescriptor
Expand Down Expand Up @@ -116,6 +116,16 @@ object Keys {
val fullServerHandlers = SettingKey(BasicKeys.fullServerHandlers)
val serverHandlers = settingKey[Seq[ServerHandler]]("User-defined server handlers.")
val cacheStores = settingKey[Seq[ActionCacheStore]]("Cache backends")
@cacheLevel(include = Array.empty)
val remoteCache = settingKey[Option[URI]]("URI of the remote cache")
@cacheLevel(include = Array.empty)
val remoteCacheTlsCertificate = settingKey[Option[File]]("Path to a TLS certificate (*.crt) that is trusted to sign server certificates")
@cacheLevel(include = Array.empty)
val remoteCacheTlsClientCertificate = settingKey[Option[File]]("Path to a TLS client certificate *.crt used with remoteCacheTlsClientKey ")
@cacheLevel(include = Array.empty)
val remoteCacheTlsClientKey = settingKey[Option[File]]("Path to a TLS client key *.pem used with remoteCacheTlsClientCertificate")
@cacheLevel(include = Array.empty)
val remoteCacheHeaders = settingKey[Seq[String]]("List of key=value headers to be sent to the remote cache.")
val rootOutputDirectory = SettingKey(BasicKeys.rootOutputDirectory)

// val analysis = AttributeKey[CompileAnalysis]("analysis", "Analysis of compilation, including dependencies and generated outputs.", DSetting)
Expand Down
5 changes: 5 additions & 0 deletions main/src/main/scala/sbt/RemoteCache.scala
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,11 @@ object RemoteCache {
DiskActionCacheStore(localCacheDirectory.value.toPath(), fileConverter.value)
)
},
remoteCache := SysProp.remoteCache,
remoteCacheTlsCertificate := SysProp.remoteCacheTlsCertificate,
remoteCacheTlsClientCertificate := SysProp.remoteCacheTlsClientCertificate,
remoteCacheTlsClientKey := SysProp.remoteCacheTlsClientKey,
remoteCacheHeaders := SysProp.remoteCacheHeaders,
adpi2 marked this conversation as resolved.
Show resolved Hide resolved
)

lazy val projectSettings: Seq[Def.Setting[_]] = (Seq(
Expand Down
80 changes: 39 additions & 41 deletions main/src/main/scala/sbt/internal/Aggregation.scala
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,11 @@ object Aggregation {
success: Boolean
)

final case class Complete[T](
final case class Complete[A](
start: Long,
stop: Long,
results: sbt.Result[Seq[KeyValue[T]]],
results: sbt.Result[Seq[KeyValue[A]]],
cacheSummary: String,
state: State
)

Expand Down Expand Up @@ -68,44 +69,43 @@ object Aggregation {
)(implicit display: Show[ScopedKey[_]]): Parser[() => State] =
Command.applyEffect(seqParser(ps))(ts => runTasks(s, ts, DummyTaskMap(Nil), show))

private def showRun[T](complete: Complete[T], show: ShowConfig)(implicit
display: Show[ScopedKey[_]]
): Unit = {
import complete._
private def showRun[A](complete: Complete[A], show: ShowConfig)(implicit
display: Show[ScopedKey[?]]
): Unit =
import complete.*
val log = state.log
val extracted = Project.extract(state)
val success = results match
case Result.Value(_) => true
case Result.Inc(_) => false
results.toEither.foreach { r =>
if (show.taskValues) printSettings(r, show.print)
if show.taskValues then printSettings(r, show.print) else ()
}
if (show.success && !state.get(suppressShow).getOrElse(false))
printSuccess(start, stop, extracted, success, log)
}
if show.success && !state.get(suppressShow).getOrElse(false) then
printSuccess(start, stop, extracted, success, cacheSummary, log)
else ()

def timedRun[T](
def timedRun[A](
s: State,
ts: Values[Task[T]],
extra: DummyTaskMap
): Complete[T] = {
ts: Values[Task[A]],
extra: DummyTaskMap,
): Complete[A] =
import EvaluateTask._
import std.TaskExtra._

val extracted = Project extract s
val extracted = Project.extract(s)
import extracted.structure
val toRun = ts.map { case KeyValue(k, t) => t.map(v => KeyValue(k, v)) }.join
val roots = ts.map { case KeyValue(k, _) => k }
val config = extractedTaskConfig(extracted, structure, s)

val start = System.currentTimeMillis
val (newS, result) = withStreams(structure, s) { str =>
val cacheEventLog = Def.cacheConfiguration.cacheEventLog
cacheEventLog.clear()
val (newS, result) = withStreams(structure, s): str =>
val transform = nodeView(s, str, roots, extra)
runTask(toRun, s, str, structure.index.triggers, config)(using transform)
}
val stop = System.currentTimeMillis
Complete(start, stop, result, newS)
}
val cacheSummary = cacheEventLog.summary
Complete(start, stop, result, cacheSummary, newS)
adpi2 marked this conversation as resolved.
Show resolved Hide resolved

def runTasks[A1](
s: State,
Expand All @@ -124,20 +124,22 @@ object Aggregation {
stop: Long,
extracted: Extracted,
success: Boolean,
log: Logger
): Unit = {
import extracted._
cacheSummary: String,
log: Logger,
): Unit =
import extracted.*
def get(key: SettingKey[Boolean]): Boolean =
(currentRef / key).get(structure.data) getOrElse true

if (get(showSuccess)) {
if (get(showTiming)) {
val msg = timingString(start, stop, structure.data, currentRef)
if (success) log.success(msg) else if (Terminal.get.isSuccessEnabled) log.error(msg)
} else if (success)
log.success("")
}
}
if get(showSuccess) then
if get(showTiming) then
val msg = timingString(start, stop, structure.data, currentRef) + (
if cacheSummary == "" then ""
else ", " + cacheSummary
)
if success then log.success(msg)
else if Terminal.get.isSuccessEnabled then log.error(msg)
else if success then log.success("")
else ()

private def timingString(
startTime: Long,
Expand All @@ -149,23 +151,19 @@ object Aggregation {
timing(format, startTime, endTime)
}

def timing(format: java.text.DateFormat, startTime: Long, endTime: Long): String = {
val nowString = format.format(new java.util.Date(endTime))
def timing(format: java.text.DateFormat, startTime: Long, endTime: Long): String =
val total = (endTime - startTime + 500) / 1000
val totalString = s"$total s" +
(if (total <= 60) ""
(if total <= 60 then ""
else {
val maybeHours = total / 3600 match {
val maybeHours = total / 3600 match
case 0 => ""
case h => f"$h%02d:"
}
val mins = f"${total % 3600 / 60}%02d"
val secs = f"${total % 60}%02d"
s" ($maybeHours$mins:$secs)"
})

s"Total time: $totalString, completed $nowString"
}
s"elapsed time: $totalString"

def defaultFormat: DateFormat = {
import java.text.DateFormat
Expand Down
16 changes: 16 additions & 0 deletions main/src/main/scala/sbt/internal/SysProp.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ package sbt
package internal

import java.io.File
import java.net.URI
import java.nio.file.{ Path, Paths }
import java.util.Locale

Expand Down Expand Up @@ -98,6 +99,21 @@ object SysProp {
def legacyTestReport: Boolean = getOrFalse("sbt.testing.legacyreport")
def semanticdb: Boolean = getOrFalse("sbt.semanticdb")
def forceServerStart: Boolean = getOrFalse("sbt.server.forcestart")
def remoteCache: Option[URI] = sys.props
.get("sbt.remote_cache")
.map(URI(_))
def remoteCacheTlsCertificate: Option[File] = sys.props
.get("sbt.remote_cache.tls_certificate")
.map(File(_))
def remoteCacheTlsClientCertificate: Option[File] = sys.props
.get("sbt.remote_cache.tls_client_certificate")
.map(File(_))
def remoteCacheTlsClientKey: Option[File] = sys.props
.get("sbt.remote_cache.tls_client_key")
.map(File(_))
def remoteCacheHeaders: List[String] = sys.props
.get("sbt.remote_cache.header")
.toList

def watchMode: String =
sys.props.get("sbt.watch.mode").getOrElse("auto")
Expand Down
18 changes: 9 additions & 9 deletions main/src/test/scala/sbt/internal/AggregationSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,14 @@ object AggregationSpec extends verify.BasicTestSuite {
val timing = Aggregation.timing(Aggregation.defaultFormat, 0, _: Long)

test("timing should format total time properly") {
assert(timing(101).startsWith("Total time: 0 s,"))
assert(timing(1000).startsWith("Total time: 1 s,"))
assert(timing(3000).startsWith("Total time: 3 s,"))
assert(timing(30399).startsWith("Total time: 30 s,"))
assert(timing(60399).startsWith("Total time: 60 s,"))
assert(timing(60699).startsWith("Total time: 61 s (01:01),"))
assert(timing(303099).startsWith("Total time: 303 s (05:03),"))
assert(timing(6003099).startsWith("Total time: 6003 s (01:40:03),"))
assert(timing(96003099).startsWith("Total time: 96003 s (26:40:03),"))
assert(timing(101).startsWith("elapsed time: 0 s"))
assert(timing(1000).startsWith("elapsed time: 1 s"))
assert(timing(3000).startsWith("elapsed time: 3 s"))
assert(timing(30399).startsWith("elapsed time: 30 s"))
assert(timing(60399).startsWith("elapsed time: 60 s"))
assert(timing(60699).startsWith("elapsed time: 61 s (01:01)"))
assert(timing(303099).startsWith("elapsed time: 303 s (05:03)"))
assert(timing(6003099).startsWith("elapsed time: 6003 s (01:40:03)"))
assert(timing(96003099).startsWith("elapsed time: 96003 s (26:40:03)"))
}
}
6 changes: 6 additions & 0 deletions project/ContrabandConfig.scala
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,12 @@ object ContrabandConfig {
case "scalajson.ast.unsafe.JValue" | "sjsonnew.shaded.scalajson.ast.unsafe.JValue" => { _ =>
"sbt.internal.util.codec.JValueFormats" :: Nil
}
case "xsbti.HashedVirtualFileRef" => { _ =>
"sbt.internal.util.codec.HashedVirtualFileRefFormats" :: Nil
}
case "java.nio.ByteBuffer" => { _ =>
"sbt.internal.util.codec.ByteBufferFormats" :: Nil
}
}

/** Returns the list of formats required to encode the given `TpeRef`. */
Expand Down
4 changes: 3 additions & 1 deletion project/Dependencies.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ object Dependencies {
private val ioVersion = nightlyVersion.getOrElse("1.8.0")
private val lmVersion =
sys.props.get("sbt.build.lm.version").orElse(nightlyVersion).getOrElse("2.0.0-alpha13")
val zincVersion = nightlyVersion.getOrElse("2.0.0-alpha12")
val zincVersion = nightlyVersion.getOrElse("2.0.0-alpha13")

private val sbtIO = "org.scala-sbt" %% "io" % ioVersion

Expand Down Expand Up @@ -106,6 +106,8 @@ object Dependencies {
val junit = "junit" % "junit" % "4.13.1"
val scalaVerify = "com.eed3si9n.verify" %% "verify" % "1.0.0"
val templateResolverApi = "org.scala-sbt" % "template-resolver" % "0.1"
val remoteapis =
"com.eed3si9n.remoteapis.shaded" % "shaded-remoteapis-java" % "2.3.0-M1-52317e00d8d4c37fa778c628485d220fb68a8d08"

val scalaCompiler = "org.scala-lang" %% "scala3-compiler" % scala3

Expand Down