Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds run support for deploy_jar targets #14352

Merged
merged 10 commits into from Feb 7, 2022
3 changes: 2 additions & 1 deletion src/python/pants/backend/experimental/java/register.py
Expand Up @@ -12,7 +12,7 @@
JunitTestTarget,
)
from pants.backend.java.target_types import rules as target_types_rules
from pants.jvm import classpath, jdk_rules, resources
from pants.jvm import classpath, jdk_rules, resources, run_deploy_jar
from pants.jvm import util_rules as jvm_util_rules
from pants.jvm.dependency_inference import symbol_mapper
from pants.jvm.goals import lockfile
Expand Down Expand Up @@ -52,4 +52,5 @@ def rules():
*jdk_rules.rules(),
*target_types_rules(),
*jvm_tool.rules(),
*run_deploy_jar.rules(),
]
3 changes: 2 additions & 1 deletion src/python/pants/backend/experimental/scala/register.py
Expand Up @@ -14,7 +14,7 @@
)
from pants.backend.scala.target_types import rules as target_types_rules
from pants.backend.scala.test import scalatest
from pants.jvm import classpath, jdk_rules, resources
from pants.jvm import classpath, jdk_rules, resources, run_deploy_jar
from pants.jvm import util_rules as jvm_util_rules
from pants.jvm.goals import lockfile
from pants.jvm.package import deploy_jar
Expand Down Expand Up @@ -56,4 +56,5 @@ def rules():
*target_types_rules(),
*jvm_tool.rules(),
*resources.rules(),
*run_deploy_jar.rules(),
]
37 changes: 37 additions & 0 deletions src/python/pants/backend/scala/compile/scalac.py
Expand Up @@ -4,6 +4,7 @@
from __future__ import annotations

import logging
from dataclasses import dataclass
from itertools import chain

from pants.backend.java.target_types import JavaFieldSet, JavaGeneratorFieldSet, JavaSourceField
Expand Down Expand Up @@ -39,6 +40,11 @@ class CompileScalaSourceRequest(ClasspathEntryRequest):
field_sets_consume_only = (JavaFieldSet, JavaGeneratorFieldSet)


@dataclass(frozen=True)
class ScalaLibraryRequest:
    """Request for the `org.scala-lang:scala-library` artifact at a specific version.

    Resolved into a `ClasspathEntry` by the `fetch_scala_library` rule.
    """

    # Version string used for the `scala-library` coordinate.
    version: str


@rule(desc="Compile with scalac")
async def compile_scala_source(
scala: ScalaSubsystem,
Expand Down Expand Up @@ -68,6 +74,17 @@ async def compile_scala_source(
exit_code=1,
)

all_dependency_jars = [
filename
for dependency in direct_dependency_classpath_entries
for filename in dependency.filenames
]
if not any(
filename.startswith("org.scala-lang_scala-library_") for filename in all_dependency_jars
):
scala_library = await Get(ClasspathEntry, ScalaLibraryRequest(scala.version))
direct_dependency_classpath_entries += (scala_library,)

Comment on lines +85 to +90
Copy link
Sponsor Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that we should probably consider this to be a workaround for #14171, rather than a fix. If this ends up needing another iteration, adding a TODO here pointing to that ticket would be good.

component_members_with_sources = tuple(
t for t in request.component.members if t.has_field(SourcesField)
)
Expand Down Expand Up @@ -189,6 +206,26 @@ async def compile_scala_source(
)


@rule
async def fetch_scala_library(request: ScalaLibraryRequest) -> ClasspathEntry:
    """Resolve the `org.scala-lang:scala-library` jar for the requested version.

    The artifact is requested as a `ToolClasspath`, and its digest and file names are
    repackaged as a `ClasspathEntry`.
    """
    scala_library_coordinate = Coordinate(
        group="org.scala-lang",
        artifact="scala-library",
        version=request.version,
    )
    requirements = ArtifactRequirements.from_coordinates([scala_library_coordinate])

    tool_classpath = await Get(
        ToolClasspath,
        ToolClasspathRequest(artifact_requirements=requirements),
    )

    return ClasspathEntry(tool_classpath.digest, tool_classpath.content.files)


def rules():
return [
*collect_rules(),
Expand Down
21 changes: 12 additions & 9 deletions src/python/pants/jvm/compile_test.py
Expand Up @@ -216,16 +216,19 @@ def test_compile_mixed(rule_runner: RuleRunner) -> None:
rendered_classpath = rule_runner.request(
RenderedClasspath, [Addresses([Address(spec_path="", target_name="main")])]
)
assert rendered_classpath.content == {
".Example.scala.main.scalac.jar": {
"META-INF/MANIFEST.MF",
"org/pantsbuild/example/Main$.class",
"org/pantsbuild/example/Main.class",
},
"lib.C.java.javac.jar": {
"org/pantsbuild/example/lib/C.class",
},

assert rendered_classpath.content[".Example.scala.main.scalac.jar"] == {
"META-INF/MANIFEST.MF",
"org/pantsbuild/example/Main$.class",
"org/pantsbuild/example/Main.class",
}
assert rendered_classpath.content["lib.C.java.javac.jar"] == {
"org/pantsbuild/example/lib/C.class",
}
assert any(
key.startswith("org.scala-lang_scala-library_") for key in rendered_classpath.content.keys()
)
assert len(rendered_classpath.content.keys()) == 3


@maybe_skip_jdk_test
Expand Down
27 changes: 23 additions & 4 deletions src/python/pants/jvm/jdk_rules.py
Expand Up @@ -93,7 +93,22 @@ async def setup_jdk(coursier: Coursier, jvm: JvmSubsystem, bash: BashBinary) ->
coursier_jdk_option = shlex.quote(f"--jvm={jvm.jdk}")
# NB: We `set +e` in the subshell to ensure that it exits as well.
# see https://unix.stackexchange.com/a/23099
java_home_command = " ".join(("set +e;", *coursier.args(["java-home", coursier_jdk_option])))

def prefixed(arg: str) -> str:
if arg.startswith("__"):
return f"${{PANTS_INTERNAL_ABSOLUTE_PREFIX}}{arg}"
chrisjrn marked this conversation as resolved.
Show resolved Hide resolved
else:
return arg

optionally_prefixed_coursier_args = [
prefixed(arg) for arg in coursier.args(["java-home", coursier_jdk_option])
]
java_home_command = " ".join(("set +e;", *optionally_prefixed_coursier_args))

env = {
"PANTS_INTERNAL_ABSOLUTE_PREFIX": "",
**coursier.env,
}

java_version_result = await Get(
FallibleProcessResult,
Expand All @@ -105,7 +120,7 @@ async def setup_jdk(coursier: Coursier, jvm: JvmSubsystem, bash: BashBinary) ->
),
append_only_caches=coursier.append_only_caches,
immutable_input_digests=coursier.immutable_input_digests,
env=coursier.env,
env=env,
description=f"Ensure download of JDK {coursier_jdk_option}.",
cache_scope=ProcessCacheScope.PER_RESTART_SUCCESSFUL,
level=LogLevel.DEBUG,
Expand Down Expand Up @@ -135,7 +150,7 @@ async def setup_jdk(coursier: Coursier, jvm: JvmSubsystem, bash: BashBinary) ->
{version_comment}
set -eu

/bin/ln -s "$({java_home_command})" "{JdkSetup.java_home}"
/bin/ln -s "$({java_home_command})" "${{PANTS_INTERNAL_ABSOLUTE_PREFIX}}{JdkSetup.java_home}"
exec "$@"
"""
)
Expand Down Expand Up @@ -232,7 +247,11 @@ async def jvm_process(bash: BashBinary, jdk_setup: JdkSetup, request: JvmProcess
**jdk_setup.immutable_input_digests,
**request.extra_immutable_input_digests,
}
env = {**jdk_setup.env, **request.extra_env}
env = {
"PANTS_INTERNAL_ABSOLUTE_PREFIX": "",
**jdk_setup.env,
**request.extra_env,
}

use_nailgun = []
if request.use_nailgun:
Expand Down
3 changes: 2 additions & 1 deletion src/python/pants/jvm/package/deploy_jar.py
Expand Up @@ -13,6 +13,7 @@
OutputPathField,
PackageFieldSet,
)
from pants.core.goals.run import RunFieldSet
from pants.core.util_rules.archive import ZipBinary
from pants.engine.addresses import Addresses
from pants.engine.fs import EMPTY_DIGEST, AddPrefix, CreateDigest, Digest, FileContent, MergeDigests
Expand All @@ -39,7 +40,7 @@


@dataclass(frozen=True)
class DeployJarFieldSet(PackageFieldSet):
class DeployJarFieldSet(PackageFieldSet, RunFieldSet):
required_fields = (
JvmMainClassNameField,
Dependencies,
Expand Down
136 changes: 136 additions & 0 deletions src/python/pants/jvm/run_deploy_jar.py
@@ -0,0 +1,136 @@
# Copyright 2020 Pants project contributors (see CONTRIBUTORS.md).
# Licensed under the Apache License, Version 2.0 (see LICENSE).

import dataclasses
import logging
from dataclasses import dataclass
from typing import Iterable

from pants.core.goals.package import BuiltPackage
from pants.core.goals.run import RunFieldSet, RunRequest
from pants.engine.fs import EMPTY_DIGEST, Digest, MergeDigests
from pants.engine.internals.native_engine import AddPrefix
from pants.engine.process import Process, ProcessResult
from pants.engine.rules import Get, MultiGet, collect_rules, rule
from pants.engine.unions import UnionRule
from pants.jvm.jdk_rules import JdkSetup, JvmProcess
from pants.jvm.package.deploy_jar import DeployJarFieldSet
from pants.util.frozendict import FrozenDict
from pants.util.logging import LogLevel

logger = logging.getLogger(__name__)


@dataclass
class __RuntimeJvm:
"""Allows Coursier to download a JDK into a Digest, rather than an append-only cache for use
with `pants run`.

This is a hideous stop-gap, which will no longer be necessary once `InteractiveProcess` supports
append-only caches.
Copy link
Sponsor Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've expanded #13852 to cover this: if this goes through another iteration, it would be good to link there.

"""

digest: Digest


@rule(level=LogLevel.DEBUG)
async def create_deploy_jar_run_request(
field_set: DeployJarFieldSet,
runtime_jvm: __RuntimeJvm,
Copy link
Sponsor Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not clear how this ends up being used... it's being captured, but I don't see anything rewriting the process execution to use it?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

jdk_setup: JdkSetup,
) -> RunRequest:

main_class = field_set.main_class.value
assert main_class is not None

package = await Get(BuiltPackage, DeployJarFieldSet, field_set)
assert len(package.artifacts) == 1
jar_path = package.artifacts[0].relpath
assert jar_path is not None

proc = await Get(
Process,
JvmProcess(
classpath_entries=[f"{{chroot}}/{jar_path}"],
argv=(main_class,),
input_digest=package.digest,
description=f"Run {main_class}.main(String[])",
use_nailgun=False,
),
)

support_digests = await MultiGet(
Get(Digest, AddPrefix(digest, prefix))
for prefix, digest in proc.immutable_input_digests.items()
)

support_digests += (runtime_jvm.digest,)

def prefixed(arg: str, prefixes: Iterable[str]) -> str:
if any(arg.startswith(prefix) for prefix in prefixes):
return f"{{chroot}}/{arg}"
else:
return arg

prefixes = (jdk_setup.bin_dir, jdk_setup.jdk_preparation_script, jdk_setup.java_home)
args = [prefixed(arg, prefixes) for arg in proc.argv]

env = {
**proc.env,
"PANTS_INTERNAL_ABSOLUTE_PREFIX": "{chroot}/",
}

# absolutify coursier cache envvars
for key in env:
if key.startswith("COURSIER"):
env[key] = prefixed(env[key], (jdk_setup.coursier.cache_dir,))
Comment on lines +71 to +88
Copy link
Sponsor Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Half of this is happening in the JDK support code, and the other half is happening here... it would be good for them to refer to one another with TODOs at least... possibly referencing #14386.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The complication here is the number of places where we're doing preprocessing of process arguments. This will definitely become easier once we no longer have to rewrite things.


request_digest = await Get(
Digest,
MergeDigests(
[
proc.input_digest,
*support_digests,
]
),
)

return RunRequest(
digest=request_digest,
args=args,
extra_env=env,
)


@rule
async def ensure_jdk_for_pants_run(jdk_setup: JdkSetup) -> __RuntimeJvm:
    """Run a small JVM process and capture `.cache/jdk` from its output as a `__RuntimeJvm`.

    The process is built via `JvmProcess`, then re-run with a reduced set of append-only
    caches and with `.cache/jdk` declared as an output directory.
    """
    # `tools.jar` is distributed with the JDK, so we can rely on it existing.
    jvm_process_request = JvmProcess(
        classpath_entries=[f"{jdk_setup.java_home}/lib/tools.jar"],
        argv=["com.sun.tools.javac.Main", "--version"],
        input_digest=EMPTY_DIGEST,
        description="Ensure download of JDK for `pants run` use",
    )
    base_process = await Get(Process, JvmProcess, jvm_process_request)

    # Do not treat the coursier JDK directory as an append-only cache: declaring `.cache/jdk`
    # as an output directory instead lets us capture the downloaded JDK in a `Digest`.
    capturing_process = dataclasses.replace(
        base_process,
        append_only_caches=FrozenDict(
            {
                "coursier_archive": ".cache/arc",
                "coursier_jvm": ".cache/jvm",
            }
        ),
        output_directories=(".cache/jdk",),
    )
    jdk_capture_result = await Get(ProcessResult, Process, capturing_process)

    return __RuntimeJvm(jdk_capture_result.output_digest)


def rules():
    """Return this module's collected rules plus the `RunFieldSet` union registration.

    Registering `DeployJarFieldSet` as a `RunFieldSet` union member is what makes `deploy_jar`
    targets eligible for the `run` goal.
    """
    collected = list(collect_rules())
    collected.append(UnionRule(RunFieldSet, DeployJarFieldSet))
    return collected