Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support scala_artifact #19128

Merged
merged 16 commits into from May 26, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion pants.toml
Expand Up @@ -173,7 +173,7 @@ args = ["--external-sources"]
args = ["-i 2", "-ci", "-sr"]

[pytest]
args = ["--no-header"]
args = ["--no-header", "-vv"]
execution_slot_var = "TEST_EXECUTION_SLOT"
install_from_resolve = "pytest"
requirements = ["//3rdparty/python:pytest"]
Expand Down
2 changes: 2 additions & 0 deletions src/python/pants/backend/experimental/scala/register.py
Expand Up @@ -14,6 +14,7 @@
ScalaSourceTarget,
ScalatestTestsGeneratorTarget,
ScalatestTestTarget,
ScalaArtifactTarget
)
from pants.backend.scala.target_types import rules as target_types_rules
from pants.backend.scala.test import scalatest
Expand All @@ -32,6 +33,7 @@ def target_types():
ScalacPluginTarget,
ScalatestTestTarget,
ScalatestTestsGeneratorTarget,
ScalaArtifactTarget,
*jvm_common.target_types(),
*wrap_scala.target_types,
]
Expand Down
4 changes: 4 additions & 0 deletions src/python/pants/backend/scala/BUILD
Expand Up @@ -2,3 +2,7 @@
# Licensed under the Apache License, Version 2.0 (see LICENSE).

python_sources()

python_tests(
name="tests",
)
156 changes: 155 additions & 1 deletion src/python/pants/backend/scala/target_types.py
Expand Up @@ -5,14 +5,19 @@

from dataclasses import dataclass

from pants.backend.scala.subsystems.scala import ScalaSubsystem
from pants.backend.scala.subsystems.scala_infer import ScalaInferSubsystem
from pants.build_graph.build_file_aliases import BuildFileAliases
from pants.core.goals.test import TestExtraEnvVarsField, TestTimeoutField
from pants.engine.rules import collect_rules, rule
from pants.engine.target import (
COMMON_TARGET_FIELDS,
AsyncFieldMixin,
BoolField,
Dependencies,
FieldSet,
GeneratedTargets,
GenerateTargetsRequest,
MultipleSourcesField,
OverridesField,
SingleSourceField,
Expand All @@ -22,15 +27,25 @@
TargetFilesGenerator,
TargetFilesGeneratorSettings,
TargetFilesGeneratorSettingsRequest,
TargetGenerator,
generate_file_based_overrides_field_help_message,
generate_multiple_sources_field_help_message,
)
from pants.engine.unions import UnionRule
from pants.engine.unions import UnionMembership, UnionRule
from pants.jvm import target_types as jvm_target_types
from pants.jvm.subsystems import JvmSubsystem
from pants.jvm.target_types import (
JunitTestExtraEnvVarsField,
JunitTestSourceField,
JunitTestTimeoutField,
JvmArtifactArtifactField,
JvmArtifactExcludeDependenciesField,
JvmArtifactExclusionRule,
JvmArtifactGroupField,
JvmArtifactPackagesField,
JvmArtifactResolveField,
JvmArtifactTarget,
JvmArtifactVersionField,
JvmJdkField,
JvmMainClassNameField,
JvmProvidesTypesField,
Expand Down Expand Up @@ -359,10 +374,149 @@ class ScalacPluginTarget(Target):
)


# -----------------------------------------------------------------------------------------------
# `scala_artifact` target
# -----------------------------------------------------------------------------------------------


class ScalaArtifactGroupField(JvmArtifactGroupField):
pass


class ScalaArtifactArtifactField(JvmArtifactArtifactField):
pass


class ScalaArtifactVersionField(JvmArtifactVersionField):
pass


@dataclass(frozen=True)
class ScalaArtifactExclusionRule(JvmArtifactExclusionRule):
alias = "scala_exclude"

full_crossversion: bool = False


class ScalaArtifactExcludeDependenciesField(JvmArtifactExcludeDependenciesField):
pass


class ScalaArtifactResolveField(JvmArtifactResolveField):
pass


class ScalaArtifactPackagesField(JvmArtifactPackagesField):
pass


class ScalaArtifactFullCrossversionField(BoolField):
alias = "full_crossversion"
default = False
help = help_text("If enabled, it will use the full Scala version in the artifact suffix.")


@dataclass(frozen=True)
class ScalaArtifactFieldSet(FieldSet):
group: ScalaArtifactGroupField
artifact: ScalaArtifactArtifactField
version: ScalaArtifactVersionField
packages: ScalaArtifactPackagesField
resolve: ScalaArtifactResolveField
excludes: ScalaArtifactExcludeDependenciesField
full_crossversion: ScalaArtifactFullCrossversionField

required_fields = (
ScalaArtifactGroupField,
ScalaArtifactArtifactField,
ScalaArtifactVersionField,
ScalaArtifactPackagesField,
)


class ScalaArtifactTarget(TargetGenerator):
alias = "scala_artifact"
core_fields = (
*COMMON_TARGET_FIELDS,
*ScalaArtifactFieldSet.required_fields,
ScalaArtifactResolveField,
ScalaArtifactExcludeDependenciesField,
ScalaArtifactFullCrossversionField,
JvmJdkField,
)
copied_fields = (
*COMMON_TARGET_FIELDS,
ScalaArtifactGroupField,
ScalaArtifactVersionField,
ScalaArtifactPackagesField,
)
moved_fields = (ScalaArtifactResolveField, JvmJdkField)


class GenerateJvmArtifactForScalaTargets(GenerateTargetsRequest):
generate_from = ScalaArtifactTarget


@rule
async def generate_jvm_artifact_targets(
request: GenerateJvmArtifactForScalaTargets,
jvm: JvmSubsystem,
scala: ScalaSubsystem,
union_membership: UnionMembership,
) -> GeneratedTargets:
field_set = ScalaArtifactFieldSet.create(request.generator)
scala_version = scala.version_for_resolve(field_set.resolve.normalized_value(jvm))
scala_version_parts = scala_version.split(".")

def scala_suffix(full_crossversion: bool) -> str:
if full_crossversion:
return scala_version
elif int(scala_version_parts[0]) >= 3:
return scala_version_parts[0]

return f"{scala_version_parts[0]}.{scala_version_parts[1]}"

exclude_dependencies_field = {}
if field_set.excludes.value:
exclusion_rules = []
for exclusion_rule in field_set.excludes.value:
if not isinstance(exclusion_rule, ScalaArtifactExclusionRule):
exclusion_rules.append(exclusion_rule)
else:
excluded_artifact_name = None
if exclusion_rule.artifact:
excluded_artifact_name = f"{exclusion_rule.artifact}_{scala_suffix(exclusion_rule.full_crossversion)}"
exclusion_rules.append(
JvmArtifactExclusionRule(
group=exclusion_rule.group, artifact=excluded_artifact_name
)
)
exclude_dependencies_field[JvmArtifactExcludeDependenciesField.alias] = exclusion_rules

artifact_name = f"{field_set.artifact.value}_{scala_suffix(field_set.full_crossversion.value)}"
jvm_artifact_target = JvmArtifactTarget(
{
**request.template,
JvmArtifactArtifactField.alias: artifact_name,
**exclude_dependencies_field,
},
request.generator.address.create_generated(artifact_name),
union_membership,
residence_dir=request.generator.address.spec_path,
)

return GeneratedTargets(request.generator, (jvm_artifact_target,))


def rules():
return (
*collect_rules(),
*jvm_target_types.rules(),
*ScalaFieldSet.jvm_rules(),
UnionRule(TargetFilesGeneratorSettingsRequest, ScalaSettingsRequest),
UnionRule(GenerateTargetsRequest, GenerateJvmArtifactForScalaTargets),
)


def build_file_aliases():
return BuildFileAliases(objects={ScalaArtifactExclusionRule.alias: ScalaArtifactExclusionRule})
151 changes: 151 additions & 0 deletions src/python/pants/backend/scala/target_types_test.py
@@ -0,0 +1,151 @@
# Copyright 2023 Pants project contributors (see CONTRIBUTORS.md).
# Licensed under the Apache License, Version 2.0 (see LICENSE).
from __future__ import annotations

from textwrap import dedent

import pytest

from pants.backend.scala.target_types import ScalaArtifactExclusionRule, ScalaArtifactTarget
from pants.backend.scala.target_types import rules as target_types_rules
from pants.build_graph.address import Address
from pants.engine.internals.graph import _TargetParametrizations, _TargetParametrizationsRequest
from pants.engine.internals.parametrize import Parametrize
from pants.engine.rules import QueryRule
from pants.engine.target import Target
from pants.jvm import jvm_common
from pants.jvm.target_types import (
JvmArtifactArtifactField,
JvmArtifactExclusionRule,
JvmArtifactGroupField,
JvmArtifactResolveField,
JvmArtifactTarget,
JvmArtifactVersionField,
)
from pants.testutil.rule_runner import RuleRunner


@pytest.fixture
def rule_runner() -> RuleRunner:
rule_runner = RuleRunner(
target_types=[ScalaArtifactTarget, JvmArtifactTarget],
rules=[
*target_types_rules(),
*jvm_common.rules(),
QueryRule(_TargetParametrizations, [_TargetParametrizationsRequest]),
],
objects={
"parametrize": Parametrize,
JvmArtifactExclusionRule.alias: JvmArtifactExclusionRule,
ScalaArtifactExclusionRule.alias: ScalaArtifactExclusionRule,
},
)
return rule_runner


def assert_generated(
rule_runner: RuleRunner,
address: Address,
*,
build_content: str,
scala_versions_per_resolve: dict[str, str],
expected_targets: set[Target],
) -> None:
rule_runner.write_files({"BUILD": build_content})
rule_runner.set_options([f"--scala-version-for-resolve={repr(scala_versions_per_resolve)}"])

parametrizations = rule_runner.request(
_TargetParametrizations,
[
_TargetParametrizationsRequest(
address,
description_of_origin="tests",
),
],
)
assert expected_targets == {
t for parametrization in parametrizations for t in parametrization.parametrization.values()
}


def test_generate_jvm_artifact_based_on_resolve(rule_runner: RuleRunner) -> None:
assert_generated(
rule_runner,
Address("", target_name="test"),
build_content=dedent(
"""\
scala_artifact(
name="test",
group="com.example",
artifact="example-gen",
version="3.4.0",
resolve="current",
)
"""
),
scala_versions_per_resolve={"current": "2.13.10"},
expected_targets={
JvmArtifactTarget(
{
JvmArtifactGroupField.alias: "com.example",
JvmArtifactArtifactField.alias: "example-gen_2.13",
JvmArtifactVersionField.alias: "3.4.0",
JvmArtifactResolveField.alias: "current",
},
Address(
"",
target_name="test",
generated_name="example-gen_2.13",
),
),
},
)


def test_generate_jvm_artifacts_for_parametrized_resolve(rule_runner: RuleRunner) -> None:
alonsodomin marked this conversation as resolved.
Show resolved Hide resolved
assert_generated(
rule_runner,
Address("", target_name="test"),
build_content=dedent(
"""\
scala_artifact(
name="test",
group="com.example",
artifact="example-gen",
version="2.9.0",
resolve=parametrize("latest", "previous"),
)
"""
),
scala_versions_per_resolve={"latest": "3.3.0", "previous": "2.13.10"},
expected_targets={
JvmArtifactTarget(
{
JvmArtifactGroupField.alias: "com.example",
JvmArtifactArtifactField.alias: "example-gen_3",
JvmArtifactVersionField.alias: "2.9.0",
JvmArtifactResolveField.alias: "latest",
},
Address(
"",
target_name="test",
generated_name="example-gen_3",
parameters={"resolve": "latest"},
),
),
JvmArtifactTarget(
{
JvmArtifactGroupField.alias: "com.example",
JvmArtifactArtifactField.alias: "example-gen_2.13",
JvmArtifactVersionField.alias: "2.9.0",
JvmArtifactResolveField.alias: "previous",
},
Address(
"",
target_name="test",
generated_name="example-gen_2.13",
parameters={"resolve": "previous"},
),
),
},
)
3 changes: 2 additions & 1 deletion src/python/pants/jvm/resolve/common.py
Expand Up @@ -161,6 +161,7 @@ def from_jvm_artifact_target(cls, target: Target) -> ArtifactRequirement:
"`JvmArtifactFieldSet` fields present."
)

exclusion_rules = target[JvmArtifactExcludeDependenciesField].value or ()
return ArtifactRequirement(
coordinate=Coordinate(
group=target[JvmArtifactGroupField].value,
Expand All @@ -173,7 +174,7 @@ def from_jvm_artifact_target(cls, target: Target) -> ArtifactRequirement:
if target[JvmArtifactJarSourceField].value
else None
),
excludes=frozenset(target[JvmArtifactExcludeDependenciesField].value or []) or None,
excludes=frozenset(rule.to_exclude_str() for rule in exclusion_rules) or None,
)

def with_extra_excludes(self, *excludes: str) -> ArtifactRequirement:
Expand Down