diff --git a/src/python/pants/engine/internals/graph.py b/src/python/pants/engine/internals/graph.py index a38b5a73381..9c58d355df3 100644 --- a/src/python/pants/engine/internals/graph.py +++ b/src/python/pants/engine/internals/graph.py @@ -7,7 +7,7 @@ from collections import defaultdict, deque from dataclasses import dataclass from pathlib import PurePath -from typing import DefaultDict, Dict, List, Tuple, Union +from typing import DefaultDict, Dict, Iterable, List, Tuple, Type, Union from pants.base.exceptions import ResolveError from pants.base.specs import ( @@ -27,27 +27,47 @@ BuildFileAddress, ) from pants.engine.collection import Collection -from pants.engine.fs import MergeDigests, PathGlobs, Snapshot, SourcesSnapshot +from pants.engine.fs import ( + EMPTY_SNAPSHOT, + GlobExpansionConjunction, + GlobMatchErrorBehavior, + MergeDigests, + PathGlobs, + Snapshot, + SourcesSnapshot, +) from pants.engine.internals.target_adaptor import TargetAdaptor from pants.engine.rules import RootRule, rule from pants.engine.selectors import Get, MultiGet from pants.engine.target import ( Dependencies, DependenciesRequest, + FieldSet, + FieldSetWithOrigin, + GeneratedSources, + GenerateSourcesRequest, HydratedSources, HydrateSourcesRequest, + InferDependenciesRequest, + InferredDependencies, + InjectDependenciesRequest, + InjectedDependencies, RegisteredTargetTypes, Sources, Target, Targets, + TargetsToValidFieldSets, + TargetsToValidFieldSetsRequest, TargetsWithOrigins, TargetWithOrigin, TransitiveTarget, TransitiveTargets, UnrecognizedTargetTypeException, WrappedTarget, + _AbstractFieldSet, generate_subtarget, ) +from pants.engine.unions import UnionMembership from pants.option.global_options import GlobalOptions, OwnersNotFoundBehavior from pants.source.filespec import matches_filespec from pants.util.ordered_set import FrozenOrderedSet, OrderedSet @@ -252,7 +272,7 @@ def already_covered_by_original_addresses(file_name: str, generated_address: Add # ----------------------------------------------------------------------------------------------- -# FilesystemSpecs -> Addresses +# Specs -> Addresses # ----------------------------------------------------------------------------------------------- @@ -385,18 +405,374 @@ async def resolve_sources_snapshot(specs: Specs, global_options: GlobalOptions) return SourcesSnapshot(result) +# ----------------------------------------------------------------------------------------------- +# Resolve the Sources field +# ----------------------------------------------------------------------------------------------- + + +class AmbiguousCodegenImplementationsException(Exception): + """Exception for when there are multiple codegen implementations and it is ambiguous which to + use.""" + + def __init__( + self, + generators: Iterable[Type["GenerateSourcesRequest"]], + *, + for_sources_types: Iterable[Type["Sources"]], + ) -> None: + bulleted_list_sep = "\n * " + all_same_generator_paths = ( + len(set((generator.input, generator.output) for generator in generators)) == 1 + ) + example_generator = list(generators)[0] + input = example_generator.input.__name__ + if all_same_generator_paths: + output = example_generator.output.__name__ + possible_generators = sorted(generator.__name__ for generator in generators) + super().__init__( + f"Multiple of the registered code generators can generate {output} from {input}. " + "It is ambiguous which implementation to use.\n\nPossible implementations:" + f"{bulleted_list_sep}{bulleted_list_sep.join(possible_generators)}" + ) + else: + possible_output_types = sorted( + generator.output.__name__ + for generator in generators + if issubclass(generator.output, tuple(for_sources_types)) + ) + possible_generators_with_output = [ + f"{generator.__name__} -> {generator.output.__name__}" + for generator in sorted(generators, key=lambda generator: generator.output.__name__) + ] + super().__init__( + f"Multiple of the registered code generators can generate one of " + f"{possible_output_types} from {input}. It is ambiguous which implementation to " + f"use. This can happen when the call site requests too many different output types " + f"from the same original protocol sources.\n\nPossible implementations with their " + f"output type: {bulleted_list_sep}" + f"{bulleted_list_sep.join(possible_generators_with_output)}" + ) + + +@rule +async def hydrate_sources( + request: HydrateSourcesRequest, + glob_match_error_behavior: GlobMatchErrorBehavior, + union_membership: UnionMembership, +) -> HydratedSources: + sources_field = request.field + + # First, find if there are any code generators for the input `sources_field`. This will be used + # to determine if the sources_field is valid or not. + # We could alternatively use `sources_field.can_generate()`, but we want to error if there are + # 2+ generators due to ambiguity. + generate_request_types = union_membership.get(GenerateSourcesRequest) + relevant_generate_request_types = [ + generate_request_type + for generate_request_type in generate_request_types + if isinstance(sources_field, generate_request_type.input) + and issubclass(generate_request_type.output, request.for_sources_types) + ] + if request.enable_codegen and len(relevant_generate_request_types) > 1: + raise AmbiguousCodegenImplementationsException( + relevant_generate_request_types, for_sources_types=request.for_sources_types + ) + generate_request_type = next(iter(relevant_generate_request_types), None) + + # Now, determine if any of the `for_sources_types` may be used, either because the + # sources_field is a direct subclass or can be generated into one of the valid types. + def compatible_with_sources_field(valid_type: Type[Sources]) -> bool: + is_instance = isinstance(sources_field, valid_type) + can_be_generated = ( + request.enable_codegen + and generate_request_type is not None + and issubclass(generate_request_type.output, valid_type) + ) + return is_instance or can_be_generated + + sources_type = next( + ( + valid_type + for valid_type in request.for_sources_types + if compatible_with_sources_field(valid_type) + ), + None, + ) + if sources_type is None: + return HydratedSources(EMPTY_SNAPSHOT, sources_field.filespec, sources_type=None) + + # Now, hydrate the `globs`. Even if we are going to use codegen, we will need the original + # protocol sources to be hydrated. + globs = sources_field.sanitized_raw_value + if globs is None: + return HydratedSources(EMPTY_SNAPSHOT, sources_field.filespec, sources_type=sources_type) + + conjunction = ( + GlobExpansionConjunction.all_match + if not sources_field.default or (set(globs) != set(sources_field.default)) + else GlobExpansionConjunction.any_match + ) + snapshot = await Get( + Snapshot, + PathGlobs( + (sources_field.prefix_glob_with_address(glob) for glob in globs), + conjunction=conjunction, + glob_match_error_behavior=glob_match_error_behavior, + # TODO(#9012): add line number referring to the sources field. When doing this, we'll + # likely need to `await Get(BuildFileAddress](Address)`. + description_of_origin=( + f"{sources_field.address}'s `{sources_field.alias}` field" + if glob_match_error_behavior != GlobMatchErrorBehavior.ignore + else None + ), + ), + ) + sources_field.validate_snapshot(snapshot) + + # Finally, return if codegen is not in use; otherwise, run the relevant code generator. + if not request.enable_codegen or generate_request_type is None: + return HydratedSources(snapshot, sources_field.filespec, sources_type=sources_type) + wrapped_protocol_target = await Get(WrappedTarget, Address, sources_field.address) + generated_sources = await Get( + GeneratedSources, + GenerateSourcesRequest, + generate_request_type(snapshot, wrapped_protocol_target.target), + ) + return HydratedSources( + generated_sources.snapshot, sources_field.filespec, sources_type=sources_type + ) + + +# ----------------------------------------------------------------------------------------------- +# Resolve the Dependencies field +# ----------------------------------------------------------------------------------------------- + + +class AmbiguousDependencyInferenceException(Exception): + """Exception for when there are multiple dependency inference implementations and it is + ambiguous which to use.""" + + def __init__( + self, + implementations: Iterable[Type["InferDependenciesRequest"]], + *, + from_sources_type: Type["Sources"], + ) -> None: + bulleted_list_sep = "\n * " + possible_implementations = sorted(impl.__name__ for impl in implementations) + super().__init__( + f"Multiple of the registered dependency inference implementations can infer " + f"dependencies from {from_sources_type.__name__}. It is ambiguous which " + "implementation to use.\n\nPossible implementations:" + f"{bulleted_list_sep}{bulleted_list_sep.join(possible_implementations)}" + ) + + +@rule +async def resolve_dependencies( + request: DependenciesRequest, union_membership: UnionMembership, global_options: GlobalOptions +) -> Addresses: + provided = [ + Address.parse( + dep, + relative_to=request.field.address.spec_path, + subproject_roots=global_options.options.subproject_roots, + ) + for dep in request.field.sanitized_raw_value or () + ] + + # Inject any dependencies. This is determined by the `request.field` class. For example, if + # there is a rule to inject for FortranDependencies, then FortranDependencies and any subclass + # of FortranDependencies will use that rule. + inject_request_types = union_membership.get(InjectDependenciesRequest) + injected = await MultiGet( + Get(InjectedDependencies, InjectDependenciesRequest, inject_request_type(request.field)) + for inject_request_type in inject_request_types + if isinstance(request.field, inject_request_type.inject_for) + ) + + inference_request_types = union_membership.get(InferDependenciesRequest) + inferred = InferredDependencies() + if global_options.options.dependency_inference and inference_request_types: + # Dependency inference is solely determined by the `Sources` field for a Target, so we + # re-resolve the original target to inspect its `Sources` field, if any. + wrapped_tgt = await Get(WrappedTarget, Address, request.field.address) + sources_field = wrapped_tgt.target.get(Sources) + relevant_inference_request_types = [ + inference_request_type + for inference_request_type in inference_request_types + if isinstance(sources_field, inference_request_type.infer_from) + ] + if relevant_inference_request_types: + if len(relevant_inference_request_types) > 1: + raise AmbiguousDependencyInferenceException( + relevant_inference_request_types, from_sources_type=type(sources_field) + ) + inference_request_type = relevant_inference_request_types[0] + inferred = await Get( + InferredDependencies, + InferDependenciesRequest, + inference_request_type(sources_field), + ) + + return Addresses(sorted([*provided, *itertools.chain.from_iterable(injected), *inferred])) + + +# ----------------------------------------------------------------------------------------------- +# Find valid field sets +# ----------------------------------------------------------------------------------------------- + + +class NoValidTargetsException(Exception): + def __init__( + self, + targets_with_origins: TargetsWithOrigins, + *, + valid_target_types: Iterable[Type[Target]], + goal_description: str, + ) -> None: + valid_target_aliases = sorted({target_type.alias for target_type in valid_target_types}) + invalid_target_aliases = sorted({tgt.alias for tgt in targets_with_origins.targets}) + specs = sorted( + { + target_with_origin.origin.to_spec_string() + for target_with_origin in targets_with_origins + } + ) + bulleted_list_sep = "\n * " + super().__init__( + f"{goal_description.capitalize()} only works with the following target types:" + f"{bulleted_list_sep}{bulleted_list_sep.join(valid_target_aliases)}\n\n" + f"You specified `{' '.join(specs)}`, which only included the following target types:" + f"{bulleted_list_sep}{bulleted_list_sep.join(invalid_target_aliases)}" + ) + + @classmethod + def create_from_field_sets( + cls, + targets_with_origins: TargetsWithOrigins, + *, + field_set_types: Iterable[Type[_AbstractFieldSet]], + goal_description: str, + union_membership: UnionMembership, + registered_target_types: RegisteredTargetTypes, + ) -> "NoValidTargetsException": + valid_target_types = { + target_type + for field_set_type in field_set_types + for target_type in field_set_type.valid_target_types( + registered_target_types.types, union_membership=union_membership + ) + } + return cls( + targets_with_origins, + valid_target_types=valid_target_types, + goal_description=goal_description, + ) + + +class TooManyTargetsException(Exception): + def __init__(self, targets: Iterable[Target], *, goal_description: str) -> None: + bulleted_list_sep = "\n * " + addresses = sorted(tgt.address.spec for tgt in targets) + super().__init__( + f"{goal_description.capitalize()} only works with one valid target, but was given " + f"multiple valid targets:{bulleted_list_sep}{bulleted_list_sep.join(addresses)}\n\n" + "Please select one of these targets to run." + ) + + +class AmbiguousImplementationsException(Exception): + """Exception for when a single target has multiple valid FieldSets, but the goal only expects + there to be one FieldSet.""" + + def __init__( + self, target: Target, field_sets: Iterable[_AbstractFieldSet], *, goal_description: str, + ) -> None: + # TODO: improve this error message. A better error message would explain to users how they + # can resolve the issue. + possible_field_sets_types = sorted(field_set.__class__.__name__ for field_set in field_sets) + bulleted_list_sep = "\n * " + super().__init__( + f"Multiple of the registered implementations for {goal_description} work for " + f"{target.address} (target type {repr(target.alias)}). It is ambiguous which " + "implementation to use.\n\nPossible implementations:" + f"{bulleted_list_sep}{bulleted_list_sep.join(possible_field_sets_types)}" + ) + + +@rule +def find_valid_field_sets( + request: TargetsToValidFieldSetsRequest, + targets_with_origins: TargetsWithOrigins, + union_membership: UnionMembership, + registered_target_types: RegisteredTargetTypes, +) -> TargetsToValidFieldSets: + field_set_types: Iterable[ + Union[Type[FieldSet], Type[FieldSetWithOrigin]] + ] = union_membership.union_rules[request.field_set_superclass] + targets_to_valid_field_sets = {} + for tgt_with_origin in targets_with_origins: + valid_field_sets = [ + ( + field_set_type.create(tgt_with_origin) + if issubclass(field_set_type, FieldSetWithOrigin) + else field_set_type.create(tgt_with_origin.target) + ) + for field_set_type in field_set_types + if field_set_type.is_valid(tgt_with_origin.target) + ] + if valid_field_sets: + targets_to_valid_field_sets[tgt_with_origin] = valid_field_sets + if request.error_if_no_valid_targets and not targets_to_valid_field_sets: + raise NoValidTargetsException.create_from_field_sets( + targets_with_origins, + field_set_types=field_set_types, + goal_description=request.goal_description, + union_membership=union_membership, + registered_target_types=registered_target_types, + ) + result = TargetsToValidFieldSets(targets_to_valid_field_sets) + if not request.expect_single_field_set: + return result + if len(result.targets) > 1: + raise TooManyTargetsException(result.targets, goal_description=request.goal_description) + if len(result.field_sets) > 1: + raise AmbiguousImplementationsException( + result.targets[0], result.field_sets, goal_description=request.goal_description + ) + return result + + def rules(): return [ + # Address -> Target resolve_target, resolve_targets, + # AddressWithOrigin -> TargetWithOrigin resolve_target_with_origin, resolve_targets_with_origins, + # TransitiveTargets transitive_target, transitive_targets, + # Owners find_owners, + RootRule(OwnersRequest), + # Specs -> AddressesWithOrigins addresses_with_origins_from_filesystem_specs, - resolve_sources_snapshot, resolve_addresses_with_origins, RootRule(Specs), - RootRule(OwnersRequest), + # SourcesSnapshot + resolve_sources_snapshot, + # Sources field + hydrate_sources, + RootRule(HydrateSourcesRequest), + # Dependencies field + resolve_dependencies, + RootRule(DependenciesRequest), + RootRule(InjectDependenciesRequest), + RootRule(InferDependenciesRequest), + # FieldSets + find_valid_field_sets, + RootRule(TargetsToValidFieldSetsRequest), ] diff --git a/src/python/pants/engine/internals/graph_test.py b/src/python/pants/engine/internals/graph_test.py index e65774635ee..ab50e1b5ff8 100644 --- a/src/python/pants/engine/internals/graph_test.py +++ b/src/python/pants/engine/internals/graph_test.py @@ -1,7 +1,11 @@ # Copyright 2020 Pants project contributors (see CONTRIBUTORS.md). # Licensed under the Apache License, Version 2.0 (see LICENSE). +import itertools +from dataclasses import dataclass +from pathlib import PurePath from textwrap import dedent +from typing import Iterable, List, Type import pytest @@ -14,21 +18,51 @@ SingleAddress, ) from pants.engine.addresses import Address, Addresses, AddressesWithOrigins, AddressWithOrigin -from pants.engine.fs import SourcesSnapshot -from pants.engine.internals.graph import Owners, OwnersRequest +from pants.engine.fs import ( + CreateDigest, + Digest, + DigestContents, + FileContent, + Snapshot, + SourcesSnapshot, +) +from pants.engine.internals.graph import ( + AmbiguousCodegenImplementationsException, + AmbiguousImplementationsException, + NoValidTargetsException, + Owners, + OwnersRequest, + TooManyTargetsException, +) from pants.engine.internals.scheduler import ExecutionError -from pants.engine.rules import RootRule -from pants.engine.selectors import Params +from pants.engine.rules import RootRule, rule +from pants.engine.selectors import Get, Params from pants.engine.target import ( Dependencies, DependenciesRequest, + FieldSet, + FieldSetWithOrigin, + GeneratedSources, + GenerateSourcesRequest, + HydratedSources, + HydrateSourcesRequest, + InferDependenciesRequest, + InferredDependencies, + InjectDependenciesRequest, + InjectedDependencies, Sources, + Tags, Target, Targets, + TargetsToValidFieldSets, + TargetsToValidFieldSetsRequest, + TargetsWithOrigins, + TargetWithOrigin, TransitiveTarget, TransitiveTargets, WrappedTarget, ) +from pants.engine.unions import UnionMembership, UnionRule, union from pants.init.specs_calculator import SpecsCalculator from pants.testutil.option.util import create_options_bootstrapper from pants.testutil.test_base import TestBase @@ -43,12 +77,7 @@ class MockTarget(Target): class GraphTest(TestBase): @classmethod def rules(cls): - return ( - *super().rules(), - RootRule(Addresses), - RootRule(WrappedTarget), - RootRule(FilesystemSpecs), - ) + return (*super().rules(), RootRule(Addresses), RootRule(WrappedTarget)) @classmethod def target_types(cls): @@ -130,6 +159,12 @@ def test_resolve_sources_snapshot(self) -> None: ) assert result.snapshot.files == ("demo/BUILD", "demo/f1.txt", "demo/f2.txt") + +class TestOwners(TestBase): + @classmethod + def target_types(cls): + return (MockTarget,) + def test_owners_source_file_does_not_exist(self) -> None: """Test when a source file belongs to a target, even though the file does not actually exist. @@ -213,6 +248,16 @@ def test_owners_build_file(self) -> None: Address("demo", "f2_second"), } + +class TestSpecsToAddresses(TestBase): + @classmethod + def rules(cls): + return (*super().rules(), RootRule(Addresses), RootRule(FilesystemSpecs)) + + @classmethod + def target_types(cls): + return (MockTarget,) + def test_filesystem_specs_literal_file(self) -> None: self.create_files("demo", ["f1.txt", "f2.txt"]) self.add_to_build_file("demo", "target(sources=['*.txt'])") @@ -380,3 +425,557 @@ def test_resolve_addresses(self) -> None: origin=SingleAddress("multiple_files", "multiple_files"), ), } + + +# ----------------------------------------------------------------------------------------------- +# Test FieldSets. Also see `engine/target_test.py`. +# ----------------------------------------------------------------------------------------------- + + +class FortranSources(Sources): + pass + + +class FortranTarget(Target): + alias = "fortran_target" + core_fields = (FortranSources, Tags) + + +class TestFindValidFieldSets(TestBase): + class InvalidTarget(Target): + alias = "invalid_target" + core_fields = () + + @classmethod + def target_types(cls): + return [FortranTarget, cls.InvalidTarget] + + @union + class FieldSetSuperclass(FieldSet): + pass + + @dataclass(frozen=True) + class FieldSetSubclass1(FieldSetSuperclass): + required_fields = (FortranSources,) + + sources: FortranSources + + @dataclass(frozen=True) + class FieldSetSubclass2(FieldSetSuperclass): + required_fields = (FortranSources,) + + sources: FortranSources + + @union + class FieldSetSuperclassWithOrigin(FieldSetWithOrigin): + pass + + class FieldSetSubclassWithOrigin(FieldSetSuperclassWithOrigin): + required_fields = (FortranSources,) + + sources: FortranSources + + @classmethod + def rules(cls): + return ( + *super().rules(), + RootRule(TargetsWithOrigins), + UnionRule(cls.FieldSetSuperclass, cls.FieldSetSubclass1), + UnionRule(cls.FieldSetSuperclass, cls.FieldSetSubclass2), + UnionRule(cls.FieldSetSuperclassWithOrigin, cls.FieldSetSubclassWithOrigin), + ) + + def test_find_valid_field_sets(self) -> None: + origin = FilesystemLiteralSpec("f.txt") + valid_tgt = FortranTarget({}, address=Address.parse(":valid")) + valid_tgt_with_origin = TargetWithOrigin(valid_tgt, origin) + invalid_tgt = self.InvalidTarget({}, address=Address.parse(":invalid")) + invalid_tgt_with_origin = TargetWithOrigin(invalid_tgt, origin) + + def find_valid_field_sets( + superclass: Type, + targets_with_origins: Iterable[TargetWithOrigin], + *, + error_if_no_valid_targets: bool = False, + expect_single_config: bool = False, + ) -> TargetsToValidFieldSets: + request = TargetsToValidFieldSetsRequest( + superclass, + goal_description="fake", + error_if_no_valid_targets=error_if_no_valid_targets, + expect_single_field_set=expect_single_config, + ) + return self.request_single_product( + TargetsToValidFieldSets, Params(request, TargetsWithOrigins(targets_with_origins),), + ) + + valid = find_valid_field_sets( + self.FieldSetSuperclass, [valid_tgt_with_origin, invalid_tgt_with_origin] + ) + assert valid.targets == (valid_tgt,) + assert valid.targets_with_origins == (valid_tgt_with_origin,) + assert valid.field_sets == ( + self.FieldSetSubclass1.create(valid_tgt), + self.FieldSetSubclass2.create(valid_tgt), + ) + + with pytest.raises(ExecutionError) as exc: + find_valid_field_sets( + self.FieldSetSuperclass, [valid_tgt_with_origin], expect_single_config=True + ) + assert AmbiguousImplementationsException.__name__ in str(exc.value) + + with pytest.raises(ExecutionError) as exc: + find_valid_field_sets( + self.FieldSetSuperclass, + [ + valid_tgt_with_origin, + TargetWithOrigin(FortranTarget({}, address=Address.parse(":valid2")), origin), + ], + expect_single_config=True, + ) + assert TooManyTargetsException.__name__ in str(exc.value) + + no_valid_targets = find_valid_field_sets(self.FieldSetSuperclass, [invalid_tgt_with_origin]) + assert no_valid_targets.targets == () + assert no_valid_targets.targets_with_origins == () + assert no_valid_targets.field_sets == () + + with pytest.raises(ExecutionError) as exc: + find_valid_field_sets( + self.FieldSetSuperclass, [invalid_tgt_with_origin], error_if_no_valid_targets=True + ) + assert NoValidTargetsException.__name__ in str(exc.value) + + valid_with_origin = find_valid_field_sets( + self.FieldSetSuperclassWithOrigin, [valid_tgt_with_origin, invalid_tgt_with_origin] + ) + assert valid_with_origin.targets == (valid_tgt,) + assert valid_with_origin.targets_with_origins == (valid_tgt_with_origin,) + assert valid_with_origin.field_sets == ( + self.FieldSetSubclassWithOrigin.create(valid_tgt_with_origin), + ) + + +# ----------------------------------------------------------------------------------------------- +# Test the Sources field, including codegen. Also see `engine/target_test.py`. +# ----------------------------------------------------------------------------------------------- + + +class TestSources(TestBase): + @classmethod + def rules(cls): + return (*super().rules(), RootRule(HydrateSourcesRequest)) + + def test_normal_hydration(self) -> None: + addr = Address.parse("src/fortran:lib") + self.create_files("src/fortran", files=["f1.f95", "f2.f95", "f1.f03", "ignored.f03"]) + sources = Sources(["f1.f95", "*.f03", "!ignored.f03", "!**/ignore*"], address=addr) + hydrated_sources = self.request_single_product( + HydratedSources, HydrateSourcesRequest(sources) + ) + assert hydrated_sources.snapshot.files == ("src/fortran/f1.f03", "src/fortran/f1.f95") + + # Also test that the Filespec is correct. This does not need hydration to be calculated. + assert ( + sources.filespec + == { + "includes": ["src/fortran/*.f03", "src/fortran/f1.f95"], + "excludes": ["src/fortran/**/ignore*", "src/fortran/ignored.f03"], + } + == hydrated_sources.filespec + ) + + def test_output_type(self) -> None: + class SourcesSubclass(Sources): + pass + + addr = Address.parse(":lib") + self.create_files("", files=["f1.f95"]) + + valid_sources = SourcesSubclass(["*"], address=addr) + hydrated_valid_sources = self.request_single_product( + HydratedSources, + HydrateSourcesRequest(valid_sources, for_sources_types=[SourcesSubclass]), + ) + assert hydrated_valid_sources.snapshot.files == ("f1.f95",) + assert hydrated_valid_sources.sources_type == SourcesSubclass + + invalid_sources = Sources(["*"], address=addr) + hydrated_invalid_sources = self.request_single_product( + HydratedSources, + HydrateSourcesRequest(invalid_sources, for_sources_types=[SourcesSubclass]), + ) + assert hydrated_invalid_sources.snapshot.files == () + assert hydrated_invalid_sources.sources_type is None + + def test_unmatched_globs(self) -> None: + self.create_files("", files=["f1.f95"]) + sources = Sources(["non_existent.f95"], address=Address.parse(":lib")) + with pytest.raises(ExecutionError) as exc: + self.request_single_product(HydratedSources, HydrateSourcesRequest(sources)) + assert "Unmatched glob" in str(exc.value) + assert "//:lib" in str(exc.value) + assert "non_existent.f95" in str(exc.value) + + def test_default_globs(self) -> None: + class DefaultSources(Sources): + default = ("default.f95", "default.f03", "*.f08", "!ignored.f08") + + addr = Address.parse("src/fortran:lib") + # NB: Not all globs will be matched with these files, specifically `default.f03` will not + # be matched. This is intentional to ensure that we use `any` glob conjunction rather + # than the normal `all` conjunction. + self.create_files("src/fortran", files=["default.f95", "f1.f08", "ignored.f08"]) + sources = DefaultSources(None, address=addr) + assert set(sources.sanitized_raw_value or ()) == set(DefaultSources.default) + + hydrated_sources = self.request_single_product( + HydratedSources, HydrateSourcesRequest(sources) + ) + assert hydrated_sources.snapshot.files == ("src/fortran/default.f95", "src/fortran/f1.f08") + + def test_expected_file_extensions(self) -> None: + class ExpectedExtensionsSources(Sources): + expected_file_extensions = (".f95", ".f03") + + addr = Address.parse("src/fortran:lib") + self.create_files("src/fortran", files=["s.f95", "s.f03", "s.f08"]) + sources = ExpectedExtensionsSources(["s.f*"], address=addr) + with pytest.raises(ExecutionError) as exc: + self.request_single_product(HydratedSources, HydrateSourcesRequest(sources)) + assert "s.f08" in str(exc.value) + assert str(addr) in str(exc.value) + + # Also check that we support valid sources + valid_sources = ExpectedExtensionsSources(["s.f95"], address=addr) + assert self.request_single_product( + HydratedSources, HydrateSourcesRequest(valid_sources) + ).snapshot.files == ("src/fortran/s.f95",) + + def test_expected_num_files(self) -> None: + class ExpectedNumber(Sources): + expected_num_files = 2 + + class ExpectedRange(Sources): + # We allow for 1 or 3 files + expected_num_files = range(1, 4, 2) + + self.create_files("", files=["f1.txt", "f2.txt", "f3.txt", "f4.txt"]) + + def hydrate(sources_cls: Type[Sources], sources: Iterable[str]) -> HydratedSources: + return self.request_single_product( + HydratedSources, + HydrateSourcesRequest(sources_cls(sources, address=Address.parse(":example"))), + ) + + with pytest.raises(ExecutionError) as exc: + hydrate(ExpectedNumber, []) + assert "must have 2 files" in str(exc.value) + with pytest.raises(ExecutionError) as exc: + hydrate(ExpectedRange, ["f1.txt", "f2.txt"]) + assert "must have 1 or 3 files" in str(exc.value) + + # Also check that we support valid # files. + assert hydrate(ExpectedNumber, ["f1.txt", "f2.txt"]).snapshot.files == ("f1.txt", "f2.txt") + assert hydrate(ExpectedRange, ["f1.txt"]).snapshot.files == ("f1.txt",) + assert hydrate(ExpectedRange, ["f1.txt", "f2.txt", "f3.txt"]).snapshot.files == ( + "f1.txt", + "f2.txt", + "f3.txt", + ) + + +class SmalltalkSources(Sources): + pass + + +class AvroSources(Sources): + pass + + +class AvroLibrary(Target): + alias = "avro_library" + core_fields = (AvroSources,) + + +class GenerateSmalltalkFromAvroRequest(GenerateSourcesRequest): + input = AvroSources + output = SmalltalkSources + + +@rule +async def generate_smalltalk_from_avro( + request: GenerateSmalltalkFromAvroRequest, +) -> GeneratedSources: + protocol_files = request.protocol_sources.files + + def generate_fortran(fp: str) -> FileContent: + parent = str(PurePath(fp).parent).replace("src/avro", "src/smalltalk") + file_name = f"{PurePath(fp).stem}.st" + return FileContent(str(PurePath(parent, file_name)), b"Generated") + + result = await Get(Snapshot, CreateDigest([generate_fortran(fp) for fp in protocol_files])) + return GeneratedSources(result) + + +class TestCodegen(TestBase): + @classmethod + def rules(cls): + return ( + *super().rules(), + generate_smalltalk_from_avro, + RootRule(GenerateSmalltalkFromAvroRequest), + RootRule(HydrateSourcesRequest), + UnionRule(GenerateSourcesRequest, GenerateSmalltalkFromAvroRequest), + ) + + @classmethod + def target_types(cls): + return [AvroLibrary] + + def setUp(self) -> None: + self.address = Address.parse("src/avro:lib") + self.create_files("src/avro", files=["f.avro"]) + self.add_to_build_file("src/avro", "avro_library(name='lib', sources=['*.avro'])") + self.union_membership = self.request_single_product(UnionMembership, Params()) + + def test_generate_sources(self) -> None: + protocol_sources = AvroSources(["*.avro"], address=self.address) + assert protocol_sources.can_generate(SmalltalkSources, self.union_membership) is True + + # First, get the original protocol sources. + hydrated_protocol_sources = self.request_single_product( + HydratedSources, HydrateSourcesRequest(protocol_sources) + ) + assert hydrated_protocol_sources.snapshot.files == ("src/avro/f.avro",) + + # Test directly feeding the protocol sources into the codegen rule. + wrapped_tgt = self.request_single_product(WrappedTarget, self.address) + generated_sources = self.request_single_product( + GeneratedSources, + GenerateSmalltalkFromAvroRequest( + hydrated_protocol_sources.snapshot, wrapped_tgt.target + ), + ) + assert generated_sources.snapshot.files == ("src/smalltalk/f.st",) + + # Test that HydrateSourcesRequest can also be used. + generated_via_hydrate_sources = self.request_single_product( + HydratedSources, + HydrateSourcesRequest( + protocol_sources, for_sources_types=[SmalltalkSources], enable_codegen=True + ), + ) + assert generated_via_hydrate_sources.snapshot.files == ("src/smalltalk/f.st",) + assert generated_via_hydrate_sources.sources_type == SmalltalkSources + + def test_works_with_subclass_fields(self) -> None: + class CustomAvroSources(AvroSources): + pass + + protocol_sources = CustomAvroSources(["*.avro"], address=self.address) + assert protocol_sources.can_generate(SmalltalkSources, self.union_membership) is True + generated = self.request_single_product( + HydratedSources, + HydrateSourcesRequest( + protocol_sources, for_sources_types=[SmalltalkSources], enable_codegen=True + ), + ) + assert generated.snapshot.files == ("src/smalltalk/f.st",) + + def test_cannot_generate_language(self) -> None: + class AdaSources(Sources): + pass + + protocol_sources = AvroSources(["*.avro"], address=self.address) + assert protocol_sources.can_generate(AdaSources, self.union_membership) is False + generated = self.request_single_product( + HydratedSources, + HydrateSourcesRequest( + protocol_sources, for_sources_types=[AdaSources], enable_codegen=True + ), + ) + assert generated.snapshot.files == () + assert generated.sources_type is None + + def test_ambiguous_implementations_exception(self) -> None: + # This error message is quite complex. We test that it correctly generates the message. + class SmalltalkGenerator1(GenerateSourcesRequest): + input = AvroSources + output = SmalltalkSources + + class SmalltalkGenerator2(GenerateSourcesRequest): + input = AvroSources + output = SmalltalkSources + + class AdaSources(Sources): + pass + + class AdaGenerator(GenerateSourcesRequest): + input = AvroSources + output = AdaSources + + class IrrelevantSources(Sources): + pass + + # Test when all generators have the same input and output. + exc = AmbiguousCodegenImplementationsException( + [SmalltalkGenerator1, SmalltalkGenerator2], for_sources_types=[SmalltalkSources] + ) + assert "can generate SmalltalkSources from AvroSources" in str(exc) + assert "* SmalltalkGenerator1" in str(exc) + assert "* SmalltalkGenerator2" in str(exc) + + # Test when the generators have different input and output, which usually happens because + # the call site used too expansive of a `for_sources_types` argument. + exc = AmbiguousCodegenImplementationsException( + [SmalltalkGenerator1, AdaGenerator], + for_sources_types=[SmalltalkSources, AdaSources, IrrelevantSources], + ) + assert "can generate one of ['AdaSources', 'SmalltalkSources'] from AvroSources" in str(exc) + assert "IrrelevantSources" not in str(exc) + assert "* SmalltalkGenerator1 -> SmalltalkSources" in str(exc) + assert "* AdaGenerator -> AdaSources" in str(exc) + + +# ----------------------------------------------------------------------------------------------- +# Test the Dependencies field. Also see `engine/target_test.py`. +# ----------------------------------------------------------------------------------------------- + + +class SmalltalkDependencies(Dependencies): + pass + + +class CustomSmalltalkDependencies(SmalltalkDependencies): + pass + + +class InjectSmalltalkDependencies(InjectDependenciesRequest): + inject_for = SmalltalkDependencies + + +class InjectCustomSmalltalkDependencies(InjectDependenciesRequest): + inject_for = CustomSmalltalkDependencies + + +@rule +def inject_smalltalk_deps(_: InjectSmalltalkDependencies) -> InjectedDependencies: + return InjectedDependencies([Address.parse("//:injected")]) + + +@rule +def inject_custom_smalltalk_deps(_: InjectCustomSmalltalkDependencies) -> InjectedDependencies: + return InjectedDependencies([Address.parse("//:custom_injected")]) + + +class SmalltalkLibrarySources(SmalltalkSources): + pass + + +class SmalltalkLibrary(Target): + alias = "smalltalk" + core_fields = (Dependencies, SmalltalkLibrarySources) + + +class InferSmalltalkDependencies(InferDependenciesRequest): + infer_from = SmalltalkSources + + +@rule +async def infer_smalltalk_dependencies(request: InferSmalltalkDependencies) -> InferredDependencies: + # To demo an inference rule, we simply treat each `sources` file to contain a list of + # addresses, one per line. + hydrated_sources = await Get(HydratedSources, HydrateSourcesRequest(request.sources_field)) + digest_contents = await Get(DigestContents, Digest, hydrated_sources.snapshot.digest) + all_lines = itertools.chain.from_iterable( + file_content.content.decode().splitlines() for file_content in digest_contents + ) + return InferredDependencies(Address.parse(line) for line in all_lines) + + +class TestDependencies(TestBase): + @classmethod + def rules(cls): + return ( + *super().rules(), + RootRule(DependenciesRequest), + inject_smalltalk_deps, + inject_custom_smalltalk_deps, + infer_smalltalk_dependencies, + UnionRule(InjectDependenciesRequest, InjectSmalltalkDependencies), + UnionRule(InjectDependenciesRequest, InjectCustomSmalltalkDependencies), + UnionRule(InferDependenciesRequest, InferSmalltalkDependencies), + ) + + @classmethod + def target_types(cls): + return [SmalltalkLibrary] + + def test_normal_resolution(self) -> None: + self.add_to_build_file("src/smalltalk", "smalltalk()") + addr = Address.parse("src/smalltalk") + deps_field = Dependencies(["//:dep1", "//:dep2", ":sibling"], address=addr) + assert self.request_single_product( + Addresses, Params(DependenciesRequest(deps_field), create_options_bootstrapper()) + ) == Addresses( + [ + Address.parse("//:dep1"), + Address.parse("//:dep2"), + Address.parse("src/smalltalk:sibling"), + ] + ) + + # Also test that we handle no dependencies. + empty_deps_field = Dependencies(None, address=addr) + assert self.request_single_product( + Addresses, Params(DependenciesRequest(empty_deps_field), create_options_bootstrapper()) + ) == Addresses([]) + + def test_dependency_injection(self) -> None: + self.add_to_build_file("", "smalltalk(name='target')") + + def assert_injected(deps_cls: Type[Dependencies], *, injected: List[str]) -> None: + deps_field = deps_cls(["//:provided"], address=Address.parse("//:target")) + result = self.request_single_product( + Addresses, Params(DependenciesRequest(deps_field), create_options_bootstrapper()) + ) + assert result == Addresses( + sorted(Address.parse(addr) for addr in (*injected, "//:provided")) + ) + + assert_injected(Dependencies, injected=[]) + assert_injected(SmalltalkDependencies, injected=["//:injected"]) + assert_injected(CustomSmalltalkDependencies, injected=["//:custom_injected", "//:injected"]) + + def test_dependency_inference(self) -> None: + self.add_to_build_file( + "", + dedent( + """\ + smalltalk(name='inferred1') + smalltalk(name='inferred2') + smalltalk(name='inferred3') + smalltalk(name='provided') + """ + ), + ) + self.create_file("demo/f1.st", "//:inferred1\n//:inferred2\n") + self.create_file("demo/f2.st", "//:inferred3\n") + self.add_to_build_file("demo", "smalltalk(sources=['*.st'], dependencies=['//:provided'])") + + deps_field = Dependencies(["//:provided"], address=Address.parse("demo")) + result = self.request_single_product( + Addresses, + Params( + DependenciesRequest(deps_field), + create_options_bootstrapper(args=["--dependency-inference"]), + ), + ) + assert result == Addresses( + sorted( + Address.parse(addr) + for addr in ["//:inferred1", "//:inferred2", "//:inferred3", "//:provided"] + ) + ) diff --git a/src/python/pants/engine/target.py b/src/python/pants/engine/target.py index d6ba0b67dc2..18b63407b61 100644 --- a/src/python/pants/engine/target.py +++ b/src/python/pants/engine/target.py @@ -26,19 +26,10 @@ from typing_extensions import final from pants.base.specs import OriginSpec -from pants.engine.addresses import Address, Addresses, assert_single_address +from pants.engine.addresses import Address, assert_single_address from pants.engine.collection import Collection, DeduplicatedCollection -from pants.engine.fs import ( - EMPTY_SNAPSHOT, - GlobExpansionConjunction, - GlobMatchErrorBehavior, - PathGlobs, - Snapshot, -) -from pants.engine.rules import RootRule, rule -from pants.engine.selectors import Get, MultiGet +from pants.engine.fs import Snapshot from pants.engine.unions import UnionMembership, union -from pants.option.global_options import GlobalOptions from pants.source.filespec import Filespec from pants.util.collections import ensure_list, ensure_str_list from pants.util.frozendict import FrozenDict @@ -806,49 +797,6 @@ def __init__( self.expect_single_field_set = expect_single_field_set -@rule -def find_valid_field_sets( - request: TargetsToValidFieldSetsRequest, - targets_with_origins: TargetsWithOrigins, - union_membership: UnionMembership, - registered_target_types: RegisteredTargetTypes, -) -> TargetsToValidFieldSets: - field_set_types: Iterable[ - Union[Type[FieldSet], Type[FieldSetWithOrigin]] - ] = union_membership.union_rules[request.field_set_superclass] - targets_to_valid_field_sets = {} - for tgt_with_origin in targets_with_origins: - valid_field_sets = [ - ( - field_set_type.create(tgt_with_origin) - if issubclass(field_set_type, FieldSetWithOrigin) - else field_set_type.create(tgt_with_origin.target) - ) - for field_set_type in field_set_types - if field_set_type.is_valid(tgt_with_origin.target) - ] - if valid_field_sets: - targets_to_valid_field_sets[tgt_with_origin] = valid_field_sets - if request.error_if_no_valid_targets and not targets_to_valid_field_sets: - raise NoValidTargetsException.create_from_field_sets( - targets_with_origins, - field_set_types=field_set_types, - goal_description=request.goal_description, - union_membership=union_membership, - registered_target_types=registered_target_types, - ) - result = TargetsToValidFieldSets(targets_to_valid_field_sets) - if not request.expect_single_field_set: - return result - if len(result.targets) > 1: - raise TooManyTargetsException(result.targets, goal_description=request.goal_description) - if len(result.field_sets) > 1: - raise AmbiguousImplementationsException( - result.targets[0], result.field_sets, goal_description=request.goal_description - ) - return result - - # ----------------------------------------------------------------------------------------------- # Exception messages # ----------------------------------------------------------------------------------------------- @@ -914,151 +862,6 @@ def __init__( ) -# NB: This has a tight coupling to goals. Feel free to change this if necessary. -class NoValidTargetsException(Exception): - def __init__( - self, - targets_with_origins: TargetsWithOrigins, - *, - valid_target_types: Iterable[Type[Target]], - goal_description: str, - ) -> None: - valid_target_aliases = sorted({target_type.alias for target_type in valid_target_types}) - invalid_target_aliases = sorted({tgt.alias for tgt in targets_with_origins.targets}) - specs = sorted( - { - target_with_origin.origin.to_spec_string() - for target_with_origin in targets_with_origins - } - ) - bulleted_list_sep = "\n * " - super().__init__( - f"{goal_description.capitalize()} only works with the following target types:" - f"{bulleted_list_sep}{bulleted_list_sep.join(valid_target_aliases)}\n\n" - f"You specified `{' '.join(specs)}`, which only included the following target types:" - f"{bulleted_list_sep}{bulleted_list_sep.join(invalid_target_aliases)}" - ) - - @classmethod - def create_from_field_sets( - cls, - targets_with_origins: TargetsWithOrigins, - *, - field_set_types: Iterable[Type[_AbstractFieldSet]], - goal_description: str, - union_membership: UnionMembership, - registered_target_types: RegisteredTargetTypes, - ) -> "NoValidTargetsException": - valid_target_types = { - target_type - for field_set_type in field_set_types - for target_type in field_set_type.valid_target_types( - registered_target_types.types, union_membership=union_membership - ) - } - return cls( - targets_with_origins, - valid_target_types=valid_target_types, - goal_description=goal_description, - ) - - -# NB: This has a tight coupling to goals. Feel free to change this if necessary. -class TooManyTargetsException(Exception): - def __init__(self, targets: Iterable[Target], *, goal_description: str) -> None: - bulleted_list_sep = "\n * " - addresses = sorted(tgt.address.spec for tgt in targets) - super().__init__( - f"{goal_description.capitalize()} only works with one valid target, but was given " - f"multiple valid targets:{bulleted_list_sep}{bulleted_list_sep.join(addresses)}\n\n" - "Please select one of these targets to run." - ) - - -# NB: This has a tight coupling to goals. Feel free to change this if necessary. -class AmbiguousImplementationsException(Exception): - """Exception for when a single target has multiple valid FieldSets, but the goal only expects - there to be one FieldSet.""" - - def __init__( - self, target: Target, field_sets: Iterable[_AbstractFieldSet], *, goal_description: str, - ) -> None: - # TODO: improve this error message. A better error message would explain to users how they - # can resolve the issue. - possible_field_sets_types = sorted(field_set.__class__.__name__ for field_set in field_sets) - bulleted_list_sep = "\n * " - super().__init__( - f"Multiple of the registered implementations for {goal_description} work for " - f"{target.address} (target type {repr(target.alias)}). It is ambiguous which " - "implementation to use.\n\nPossible implementations:" - f"{bulleted_list_sep}{bulleted_list_sep.join(possible_field_sets_types)}" - ) - - -class AmbiguousCodegenImplementationsException(Exception): - """Exception for when there are multiple codegen implementations and it is ambiguous which to - use.""" - - def __init__( - self, - generators: Iterable[Type["GenerateSourcesRequest"]], - *, - for_sources_types: Iterable[Type["Sources"]], - ) -> None: - bulleted_list_sep = "\n * " - all_same_generator_paths = ( - len(set((generator.input, generator.output) for generator in generators)) == 1 - ) - example_generator = list(generators)[0] - input = example_generator.input.__name__ - if all_same_generator_paths: - output = example_generator.output.__name__ - possible_generators = sorted(generator.__name__ for generator in generators) - super().__init__( - f"Multiple of the registered code generators can generate {output} from {input}. " - "It is ambiguous which implementation to use.\n\nPossible implementations:" - f"{bulleted_list_sep}{bulleted_list_sep.join(possible_generators)}" - ) - else: - possible_output_types = sorted( - generator.output.__name__ - for generator in generators - if issubclass(generator.output, tuple(for_sources_types)) - ) - possible_generators_with_output = [ - f"{generator.__name__} -> {generator.output.__name__}" - for generator in sorted(generators, key=lambda generator: generator.output.__name__) - ] - super().__init__( - f"Multiple of the registered code generators can generate one of " - f"{possible_output_types} from {input}. It is ambiguous which implementation to " - f"use. This can happen when the call site requests too many different output types " - f"from the same original protocol sources.\n\nPossible implementations with their " - f"output type: {bulleted_list_sep}" - f"{bulleted_list_sep.join(possible_generators_with_output)}" - ) - - -class AmbiguousDependencyInferenceException(Exception): - """Exception for when there are multiple dependency inference implementations and it is - ambiguous which to use.""" - - def __init__( - self, - implementations: Iterable[Type["InferDependenciesRequest"]], - *, - from_sources_type: Type["Sources"], - ) -> None: - bulleted_list_sep = "\n * " - possible_implementations = sorted(impl.__name__ for impl in implementations) - super().__init__( - f"Multiple of the registered dependency inference implementations can infer " - f"dependencies from {from_sources_type.__name__}. It is ambiguous which " - "implementation to use.\n\nPossible implementations:" - f"{bulleted_list_sep}{bulleted_list_sep.join(possible_implementations)}" - ) - - # ----------------------------------------------------------------------------------------------- # Field templates # ----------------------------------------------------------------------------------------------- @@ -1538,95 +1341,6 @@ class GeneratedSources: snapshot: Snapshot -@rule -async def hydrate_sources( - request: HydrateSourcesRequest, - glob_match_error_behavior: GlobMatchErrorBehavior, - union_membership: UnionMembership, -) -> HydratedSources: - sources_field = request.field - - # First, find if there are any code generators for the input `sources_field`. This will be used - # to determine if the sources_field is valid or not. - # We could alternatively use `sources_field.can_generate()`, but we want to error if there are - # 2+ generators due to ambiguity. - generate_request_types = union_membership.get(GenerateSourcesRequest) - relevant_generate_request_types = [ - generate_request_type - for generate_request_type in generate_request_types - if isinstance(sources_field, generate_request_type.input) - and issubclass(generate_request_type.output, request.for_sources_types) - ] - if request.enable_codegen and len(relevant_generate_request_types) > 1: - raise AmbiguousCodegenImplementationsException( - relevant_generate_request_types, for_sources_types=request.for_sources_types - ) - generate_request_type = next(iter(relevant_generate_request_types), None) - - # Now, determine if any of the `for_sources_types` may be used, either because the - # sources_field is a direct subclass or can be generated into one of the valid types. - def compatible_with_sources_field(valid_type: Type[Sources]) -> bool: - is_instance = isinstance(sources_field, valid_type) - can_be_generated = ( - request.enable_codegen - and generate_request_type is not None - and issubclass(generate_request_type.output, valid_type) - ) - return is_instance or can_be_generated - - sources_type = next( - ( - valid_type - for valid_type in request.for_sources_types - if compatible_with_sources_field(valid_type) - ), - None, - ) - if sources_type is None: - return HydratedSources(EMPTY_SNAPSHOT, sources_field.filespec, sources_type=None) - - # Now, hydrate the `globs`. Even if we are going to use codegen, we will need the original - # protocol sources to be hydrated. - globs = sources_field.sanitized_raw_value - if globs is None: - return HydratedSources(EMPTY_SNAPSHOT, sources_field.filespec, sources_type=sources_type) - - conjunction = ( - GlobExpansionConjunction.all_match - if not sources_field.default or (set(globs) != set(sources_field.default)) - else GlobExpansionConjunction.any_match - ) - snapshot = await Get( - Snapshot, - PathGlobs( - (sources_field.prefix_glob_with_address(glob) for glob in globs), - conjunction=conjunction, - glob_match_error_behavior=glob_match_error_behavior, - # TODO(#9012): add line number referring to the sources field. When doing this, we'll - # likely need to `await Get(BuildFileAddress](Address)`. - description_of_origin=( - f"{sources_field.address}'s `{sources_field.alias}` field" - if glob_match_error_behavior != GlobMatchErrorBehavior.ignore - else None - ), - ), - ) - sources_field.validate_snapshot(snapshot) - - # Finally, return if codegen is not in use; otherwise, run the relevant code generator. - if not request.enable_codegen or generate_request_type is None: - return HydratedSources(snapshot, sources_field.filespec, sources_type=sources_type) - wrapped_protocol_target = await Get(WrappedTarget, Address, sources_field.address) - generated_sources = await Get( - GeneratedSources, - GenerateSourcesRequest, - generate_request_type(snapshot, wrapped_protocol_target.target), - ) - return HydratedSources( - generated_sources.snapshot, sources_field.filespec, sources_type=sources_type - ) - - # ----------------------------------------------------------------------------------------------- # `Dependencies` field # ----------------------------------------------------------------------------------------------- @@ -1649,6 +1363,15 @@ def sanitize_raw_value( value_or_default = super().sanitize_raw_value(raw_value, address=address) if value_or_default is None: return None + try: + ensure_str_list(value_or_default) + except ValueError: + raise InvalidFieldTypeException( + address, + cls.alias, + value_or_default, + expected_type="an iterable of strings (e.g. a list of strings)", + ) return tuple(sorted(value_or_default)) @@ -1736,56 +1459,6 @@ class InferredDependencies(DeduplicatedCollection[Address]): sort_input = True -@rule -async def resolve_dependencies( - request: DependenciesRequest, union_membership: UnionMembership, global_options: GlobalOptions -) -> Addresses: - provided = [ - Address.parse( - dep, - relative_to=request.field.address.spec_path, - subproject_roots=global_options.options.subproject_roots, - ) - for dep in request.field.sanitized_raw_value or () - ] - - # Inject any dependencies. This is determined by the `request.field` class. For example, if - # there is a rule to inject for FortranDependencies, then FortranDependencies and any subclass - # of FortranDependencies will use that rule. - inject_request_types = union_membership.get(InjectDependenciesRequest) - injected = await MultiGet( - Get(InjectedDependencies, InjectDependenciesRequest, inject_request_type(request.field)) - for inject_request_type in inject_request_types - if isinstance(request.field, inject_request_type.inject_for) - ) - - inference_request_types = union_membership.get(InferDependenciesRequest) - inferred = InferredDependencies() - if global_options.options.dependency_inference and inference_request_types: - # Dependency inference is solely determined by the `Sources` field for a Target, so we - # re-resolve the original target to inspect its `Sources` field, if any. - wrapped_tgt = await Get(WrappedTarget, Address, request.field.address) - sources_field = wrapped_tgt.target.get(Sources) - relevant_inference_request_types = [ - inference_request_type - for inference_request_type in inference_request_types - if isinstance(sources_field, inference_request_type.infer_from) - ] - if relevant_inference_request_types: - if len(relevant_inference_request_types) > 1: - raise AmbiguousDependencyInferenceException( - relevant_inference_request_types, from_sources_type=type(sources_field) - ) - inference_request_type = relevant_inference_request_types[0] - inferred = await Get( - InferredDependencies, - InferDependenciesRequest, - inference_request_type(sources_field), - ) - - return Addresses(sorted([*provided, *itertools.chain.from_iterable(injected), *inferred])) - - # ----------------------------------------------------------------------------------------------- # Other common Fields used across most targets # ----------------------------------------------------------------------------------------------- @@ -1821,16 +1494,3 @@ class ProvidesField(PrimitiveField): alias = "provides" default: ClassVar[Optional[Any]] = None - - -def rules(): - return [ - find_valid_field_sets, - hydrate_sources, - resolve_dependencies, - RootRule(TargetsToValidFieldSetsRequest), - RootRule(HydrateSourcesRequest), - RootRule(DependenciesRequest), - RootRule(InjectDependenciesRequest), - RootRule(InferDependenciesRequest), - ] diff --git a/src/python/pants/engine/target_test.py b/src/python/pants/engine/target_test.py index d1018b645a4..ee99d8d02b2 100644 --- a/src/python/pants/engine/target_test.py +++ b/src/python/pants/engine/target_test.py @@ -1,53 +1,30 @@ # Copyright 2020 Pants project contributors (see CONTRIBUTORS.md). # Licensed under the Apache License, Version 2.0 (see LICENSE). -import itertools from dataclasses import dataclass from enum import Enum from pathlib import PurePath -from textwrap import dedent -from typing import Any, Dict, Iterable, List, Optional, Tuple, Type +from typing import Any, Dict, Iterable, List, Optional, Tuple import pytest from typing_extensions import final from pants.base.specs import FilesystemLiteralSpec -from pants.engine.addresses import Address, Addresses -from pants.engine.fs import ( - EMPTY_DIGEST, - CreateDigest, - Digest, - DigestContents, - FileContent, - PathGlobs, - Snapshot, -) -from pants.engine.internals.scheduler import ExecutionError -from pants.engine.rules import RootRule, rule -from pants.engine.selectors import Get, Params +from pants.engine.addresses import Address +from pants.engine.fs import EMPTY_DIGEST, PathGlobs, Snapshot +from pants.engine.rules import rule +from pants.engine.selectors import Get from pants.engine.target import ( - AmbiguousCodegenImplementationsException, - AmbiguousImplementationsException, AsyncField, BoolField, Dependencies, - DependenciesRequest, DictStringToStringField, DictStringToStringSequenceField, FieldSet, FieldSetWithOrigin, - GeneratedSources, - GenerateSourcesRequest, - HydratedSources, - HydrateSourcesRequest, - InferDependenciesRequest, - InferredDependencies, - InjectDependenciesRequest, - InjectedDependencies, InvalidFieldChoiceException, InvalidFieldException, InvalidFieldTypeException, - NoValidTargetsException, PrimitiveField, RequiredFieldMissingException, ScalarField, @@ -58,19 +35,12 @@ StringSequenceField, Tags, Target, - TargetsToValidFieldSets, - TargetsToValidFieldSetsRequest, - TargetsWithOrigins, TargetWithOrigin, - TooManyTargetsException, - WrappedTarget, generate_subtarget, generate_subtarget_address, ) -from pants.engine.unions import UnionMembership, UnionRule, union +from pants.engine.unions import UnionMembership from pants.testutil.engine.util import MockGet, run_rule -from pants.testutil.option.util import create_options_bootstrapper -from pants.testutil.test_base import TestBase from pants.util.collections import ensure_str_list from pants.util.frozendict import FrozenDict from pants.util.ordered_set import OrderedSet @@ -489,7 +459,7 @@ class NoSourcesTgt(Target): # ----------------------------------------------------------------------------------------------- -# Test FieldSet +# Test FieldSet. Also see engine/internals/graph_test.py. # ----------------------------------------------------------------------------------------------- @@ -552,122 +522,6 @@ class UnrelatedFieldSet(FieldSetWithOrigin): ) -class TestFindValidFieldSets(TestBase): - class InvalidTarget(Target): - alias = "invalid_target" - core_fields = () - - @classmethod - def target_types(cls): - return [FortranTarget, cls.InvalidTarget] - - @union - class FieldSetSuperclass(FieldSet): - pass - - @dataclass(frozen=True) - class FieldSetSubclass1(FieldSetSuperclass): - required_fields = (FortranSources,) - - sources: FortranSources - - @dataclass(frozen=True) - class FieldSetSubclass2(FieldSetSuperclass): - required_fields = (FortranSources,) - - sources: FortranSources - - @union - class FieldSetSuperclassWithOrigin(FieldSetWithOrigin): - pass - - class FieldSetSubclassWithOrigin(FieldSetSuperclassWithOrigin): - required_fields = (FortranSources,) - - sources: FortranSources - - @classmethod - def rules(cls): - return ( - *super().rules(), - RootRule(TargetsWithOrigins), - UnionRule(cls.FieldSetSuperclass, cls.FieldSetSubclass1), - UnionRule(cls.FieldSetSuperclass, cls.FieldSetSubclass2), - UnionRule(cls.FieldSetSuperclassWithOrigin, cls.FieldSetSubclassWithOrigin), - ) - - def test_find_valid_field_sets(self) -> None: - origin = FilesystemLiteralSpec("f.txt") - valid_tgt = FortranTarget({}, address=Address.parse(":valid")) - valid_tgt_with_origin = TargetWithOrigin(valid_tgt, origin) - invalid_tgt = self.InvalidTarget({}, address=Address.parse(":invalid")) - invalid_tgt_with_origin = TargetWithOrigin(invalid_tgt, origin) - - def find_valid_field_sets( - superclass: Type, - targets_with_origins: Iterable[TargetWithOrigin], - *, - error_if_no_valid_targets: bool = False, - expect_single_config: bool = False, - ) -> TargetsToValidFieldSets: - request = TargetsToValidFieldSetsRequest( - superclass, - goal_description="fake", - error_if_no_valid_targets=error_if_no_valid_targets, - expect_single_field_set=expect_single_config, - ) - return self.request_single_product( - TargetsToValidFieldSets, Params(request, TargetsWithOrigins(targets_with_origins),), - ) - - valid = find_valid_field_sets( - self.FieldSetSuperclass, [valid_tgt_with_origin, invalid_tgt_with_origin] - ) - assert valid.targets == (valid_tgt,) - assert valid.targets_with_origins == (valid_tgt_with_origin,) - assert valid.field_sets == ( - self.FieldSetSubclass1.create(valid_tgt), - self.FieldSetSubclass2.create(valid_tgt), - ) - - with pytest.raises(ExecutionError) as exc: - find_valid_field_sets( - self.FieldSetSuperclass, [valid_tgt_with_origin], expect_single_config=True - ) - assert AmbiguousImplementationsException.__name__ in str(exc.value) - - with pytest.raises(ExecutionError) as exc: - find_valid_field_sets( - self.FieldSetSuperclass, - [ - valid_tgt_with_origin, - TargetWithOrigin(FortranTarget({}, address=Address.parse(":valid2")), origin), - ], - expect_single_config=True, - ) - assert TooManyTargetsException.__name__ in str(exc.value) - - no_valid_targets = find_valid_field_sets(self.FieldSetSuperclass, [invalid_tgt_with_origin]) - assert no_valid_targets.targets == () - assert no_valid_targets.targets_with_origins == () - assert no_valid_targets.field_sets == () - - with pytest.raises(ExecutionError) as exc: - find_valid_field_sets( - self.FieldSetSuperclass, [invalid_tgt_with_origin], error_if_no_valid_targets=True - ) - assert NoValidTargetsException.__name__ in str(exc.value) - - valid_with_origin = find_valid_field_sets( - self.FieldSetSuperclassWithOrigin, [valid_tgt_with_origin, invalid_tgt_with_origin] - ) - assert valid_with_origin.targets == (valid_tgt,) - assert valid_with_origin.targets_with_origins == (valid_tgt_with_origin,) - assert valid_with_origin.field_sets == ( - self.FieldSetSubclassWithOrigin.create(valid_tgt_with_origin), - ) - - # ----------------------------------------------------------------------------------------------- # Test Field templates # ----------------------------------------------------------------------------------------------- @@ -831,446 +685,30 @@ def assert_invalid_type(raw_value: Any) -> None: # ----------------------------------------------------------------------------------------------- -# Test Sources -# ----------------------------------------------------------------------------------------------- - - -class TestSources(TestBase): - @classmethod - def rules(cls): - return (*super().rules(), RootRule(HydrateSourcesRequest)) - - def test_raw_value_sanitation(self) -> None: - addr = Address.parse(":test") - - def assert_flexible_constructor(raw_value: Iterable[str]) -> None: - assert Sources(raw_value, address=addr).sanitized_raw_value == tuple(raw_value) - - for v in [("f1.txt", "f2.txt"), ["f1.txt", "f2.txt"], OrderedSet(["f1.txt", "f2.txt"])]: - assert_flexible_constructor(v) - - def assert_invalid_type(raw_value: Any) -> None: - with pytest.raises(InvalidFieldTypeException): - Sources(raw_value, address=addr) - - for v in [0, object(), "f1.txt"]: # type: ignore[assignment] - assert_invalid_type(v) - - def test_normal_hydration(self) -> None: - addr = Address.parse("src/fortran:lib") - self.create_files("src/fortran", files=["f1.f95", "f2.f95", "f1.f03", "ignored.f03"]) - sources = Sources(["f1.f95", "*.f03", "!ignored.f03", "!**/ignore*"], address=addr) - hydrated_sources = self.request_single_product( - HydratedSources, HydrateSourcesRequest(sources) - ) - assert hydrated_sources.snapshot.files == ("src/fortran/f1.f03", "src/fortran/f1.f95") - - # Also test that the Filespec is correct. This does not need hydration to be calculated. - assert ( - sources.filespec - == { - "includes": ["src/fortran/*.f03", "src/fortran/f1.f95"], - "excludes": ["src/fortran/**/ignore*", "src/fortran/ignored.f03"], - } - == hydrated_sources.filespec - ) - - def test_output_type(self) -> None: - class SourcesSubclass(Sources): - pass - - addr = Address.parse(":lib") - self.create_files("", files=["f1.f95"]) - - valid_sources = SourcesSubclass(["*"], address=addr) - hydrated_valid_sources = self.request_single_product( - HydratedSources, - HydrateSourcesRequest(valid_sources, for_sources_types=[SourcesSubclass]), - ) - assert hydrated_valid_sources.snapshot.files == ("f1.f95",) - assert hydrated_valid_sources.sources_type == SourcesSubclass - - invalid_sources = Sources(["*"], address=addr) - hydrated_invalid_sources = self.request_single_product( - HydratedSources, - HydrateSourcesRequest(invalid_sources, for_sources_types=[SourcesSubclass]), - ) - assert hydrated_invalid_sources.snapshot.files == () - assert hydrated_invalid_sources.sources_type is None - - def test_unmatched_globs(self) -> None: - self.create_files("", files=["f1.f95"]) - sources = Sources(["non_existent.f95"], address=Address.parse(":lib")) - with pytest.raises(ExecutionError) as exc: - self.request_single_product(HydratedSources, HydrateSourcesRequest(sources)) - assert "Unmatched glob" in str(exc.value) - assert "//:lib" in str(exc.value) - assert "non_existent.f95" in str(exc.value) - - def test_default_globs(self) -> None: - class DefaultSources(Sources): - default = ("default.f95", "default.f03", "*.f08", "!ignored.f08") - - addr = Address.parse("src/fortran:lib") - # NB: Not all globs will be matched with these files, specifically `default.f03` will not - # be matched. This is intentional to ensure that we use `any` glob conjunction rather - # than the normal `all` conjunction. - self.create_files("src/fortran", files=["default.f95", "f1.f08", "ignored.f08"]) - sources = DefaultSources(None, address=addr) - assert set(sources.sanitized_raw_value or ()) == set(DefaultSources.default) - - hydrated_sources = self.request_single_product( - HydratedSources, HydrateSourcesRequest(sources) - ) - assert hydrated_sources.snapshot.files == ("src/fortran/default.f95", "src/fortran/f1.f08") - - def test_expected_file_extensions(self) -> None: - class ExpectedExtensionsSources(Sources): - expected_file_extensions = (".f95", ".f03") - - addr = Address.parse("src/fortran:lib") - self.create_files("src/fortran", files=["s.f95", "s.f03", "s.f08"]) - sources = ExpectedExtensionsSources(["s.f*"], address=addr) - with pytest.raises(ExecutionError) as exc: - self.request_single_product(HydratedSources, HydrateSourcesRequest(sources)) - assert "s.f08" in str(exc.value) - assert str(addr) in str(exc.value) - - # Also check that we support valid sources - valid_sources = ExpectedExtensionsSources(["s.f95"], address=addr) - assert self.request_single_product( - HydratedSources, HydrateSourcesRequest(valid_sources) - ).snapshot.files == ("src/fortran/s.f95",) - - def test_expected_num_files(self) -> None: - class ExpectedNumber(Sources): - expected_num_files = 2 - - class ExpectedRange(Sources): - # We allow for 1 or 3 files - expected_num_files = range(1, 4, 2) - - self.create_files("", files=["f1.txt", "f2.txt", "f3.txt", "f4.txt"]) - - def hydrate(sources_cls: Type[Sources], sources: Iterable[str]) -> HydratedSources: - return self.request_single_product( - HydratedSources, - HydrateSourcesRequest(sources_cls(sources, address=Address.parse(":example"))), - ) - - with pytest.raises(ExecutionError) as exc: - hydrate(ExpectedNumber, []) - assert "must have 2 files" in str(exc.value) - with pytest.raises(ExecutionError) as exc: - hydrate(ExpectedRange, ["f1.txt", "f2.txt"]) - assert "must have 1 or 3 files" in str(exc.value) - - # Also check that we support valid # files. - assert hydrate(ExpectedNumber, ["f1.txt", "f2.txt"]).snapshot.files == ("f1.txt", "f2.txt") - assert hydrate(ExpectedRange, ["f1.txt"]).snapshot.files == ("f1.txt",) - assert hydrate(ExpectedRange, ["f1.txt", "f2.txt", "f3.txt"]).snapshot.files == ( - "f1.txt", - "f2.txt", - "f3.txt", - ) - - -# ----------------------------------------------------------------------------------------------- -# Test Codegen -# ----------------------------------------------------------------------------------------------- - - -class SmalltalkSources(Sources): - pass - - -class AvroSources(Sources): - pass - - -class AvroLibrary(Target): - alias = "avro_library" - core_fields = (AvroSources,) - - -class GenerateSmalltalkFromAvroRequest(GenerateSourcesRequest): - input = AvroSources - output = SmalltalkSources - - -@rule -async def generate_smalltalk_from_avro( - request: GenerateSmalltalkFromAvroRequest, -) -> GeneratedSources: - protocol_files = request.protocol_sources.files - - def generate_fortran(fp: str) -> FileContent: - parent = str(PurePath(fp).parent).replace("src/avro", "src/smalltalk") - file_name = f"{PurePath(fp).stem}.st" - return FileContent(str(PurePath(parent, file_name)), b"Generated") - - result = await Get(Snapshot, CreateDigest([generate_fortran(fp) for fp in protocol_files])) - return GeneratedSources(result) - - -class TestCodegen(TestBase): - @classmethod - def rules(cls): - return ( - *super().rules(), - generate_smalltalk_from_avro, - RootRule(GenerateSmalltalkFromAvroRequest), - RootRule(HydrateSourcesRequest), - UnionRule(GenerateSourcesRequest, GenerateSmalltalkFromAvroRequest), - ) - - @classmethod - def target_types(cls): - return [AvroLibrary] - - def setUp(self) -> None: - self.address = Address.parse("src/avro:lib") - self.create_files("src/avro", files=["f.avro"]) - self.add_to_build_file("src/avro", "avro_library(name='lib', sources=['*.avro'])") - self.union_membership = self.request_single_product(UnionMembership, Params()) - - def test_generate_sources(self) -> None: - protocol_sources = AvroSources(["*.avro"], address=self.address) - assert protocol_sources.can_generate(SmalltalkSources, self.union_membership) is True - - # First, get the original protocol sources. - hydrated_protocol_sources = self.request_single_product( - HydratedSources, HydrateSourcesRequest(protocol_sources) - ) - assert hydrated_protocol_sources.snapshot.files == ("src/avro/f.avro",) - - # Test directly feeding the protocol sources into the codegen rule. - wrapped_tgt = self.request_single_product(WrappedTarget, self.address) - generated_sources = self.request_single_product( - GeneratedSources, - GenerateSmalltalkFromAvroRequest( - hydrated_protocol_sources.snapshot, wrapped_tgt.target - ), - ) - assert generated_sources.snapshot.files == ("src/smalltalk/f.st",) - - # Test that HydrateSourcesRequest can also be used. - generated_via_hydrate_sources = self.request_single_product( - HydratedSources, - HydrateSourcesRequest( - protocol_sources, for_sources_types=[SmalltalkSources], enable_codegen=True - ), - ) - assert generated_via_hydrate_sources.snapshot.files == ("src/smalltalk/f.st",) - assert generated_via_hydrate_sources.sources_type == SmalltalkSources - - def test_works_with_subclass_fields(self) -> None: - class CustomAvroSources(AvroSources): - pass - - protocol_sources = CustomAvroSources(["*.avro"], address=self.address) - assert protocol_sources.can_generate(SmalltalkSources, self.union_membership) is True - generated = self.request_single_product( - HydratedSources, - HydrateSourcesRequest( - protocol_sources, for_sources_types=[SmalltalkSources], enable_codegen=True - ), - ) - assert generated.snapshot.files == ("src/smalltalk/f.st",) - - def test_cannot_generate_language(self) -> None: - class AdaSources(Sources): - pass - - protocol_sources = AvroSources(["*.avro"], address=self.address) - assert protocol_sources.can_generate(AdaSources, self.union_membership) is False - generated = self.request_single_product( - HydratedSources, - HydrateSourcesRequest( - protocol_sources, for_sources_types=[AdaSources], enable_codegen=True - ), - ) - assert generated.snapshot.files == () - assert generated.sources_type is None - - def test_ambiguous_implementations_exception(self) -> None: - # This error message is quite complex. We test that it correctly generates the message. - class SmalltalkGenerator1(GenerateSourcesRequest): - input = AvroSources - output = SmalltalkSources - - class SmalltalkGenerator2(GenerateSourcesRequest): - input = AvroSources - output = SmalltalkSources - - class AdaSources(Sources): - pass - - class AdaGenerator(GenerateSourcesRequest): - input = AvroSources - output = AdaSources - - class IrrelevantSources(Sources): - pass - - # Test when all generators have the same input and output. - exc = AmbiguousCodegenImplementationsException( - [SmalltalkGenerator1, SmalltalkGenerator2], for_sources_types=[SmalltalkSources] - ) - assert "can generate SmalltalkSources from AvroSources" in str(exc) - assert "* SmalltalkGenerator1" in str(exc) - assert "* SmalltalkGenerator2" in str(exc) - - # Test when the generators have different input and output, which usually happens because - # the call site used too expansive of a `for_sources_types` argument. - exc = AmbiguousCodegenImplementationsException( - [SmalltalkGenerator1, AdaGenerator], - for_sources_types=[SmalltalkSources, AdaSources, IrrelevantSources], - ) - assert "can generate one of ['AdaSources', 'SmalltalkSources'] from AvroSources" in str(exc) - assert "IrrelevantSources" not in str(exc) - assert "* SmalltalkGenerator1 -> SmalltalkSources" in str(exc) - assert "* AdaGenerator -> AdaSources" in str(exc) - - -# ----------------------------------------------------------------------------------------------- -# Test Dependencies +# Test Sources and Dependencies. Also see engine/internals/graph_test.py. # ----------------------------------------------------------------------------------------------- -class SmalltalkDependencies(Dependencies): - pass - - -class CustomSmalltalkDependencies(SmalltalkDependencies): - pass - - -class InjectSmalltalkDependencies(InjectDependenciesRequest): - inject_for = SmalltalkDependencies +def test_dependencies_and_sources_fields_raw_value_sanitation() -> None: + """Ensure that both Sources and Dependencies behave like a StringSequenceField does. + Normally, we would use StringSequenceField. However, these are both AsyncFields, and + StringSequenceField is a PrimitiveField, so we end up replicating that validation logic. + """ + addr = Address.parse(":test") -class InjectCustomSmalltalkDependencies(InjectDependenciesRequest): - inject_for = CustomSmalltalkDependencies + def assert_flexible_constructor(raw_value: Iterable[str]) -> None: + assert Sources(raw_value, address=addr).sanitized_raw_value == tuple(raw_value) + assert Dependencies(raw_value, address=addr).sanitized_raw_value == tuple(raw_value) + for v in [("f1.txt", "f2.txt"), ["f1.txt", "f2.txt"], OrderedSet(["f1.txt", "f2.txt"])]: + assert_flexible_constructor(v) -@rule -def inject_smalltalk_deps(_: InjectSmalltalkDependencies) -> InjectedDependencies: - return InjectedDependencies([Address.parse("//:injected")]) - - -@rule -def inject_custom_smalltalk_deps(_: InjectCustomSmalltalkDependencies) -> InjectedDependencies: - return InjectedDependencies([Address.parse("//:custom_injected")]) - - -# NB: We subclass to ensure that dependency inference works properly with subclasses. -class SmalltalkLibrarySources(SmalltalkSources): - pass - - -class SmalltalkLibrary(Target): - alias = "smalltalk" - core_fields = (Dependencies, SmalltalkLibrarySources) - - -class InferSmalltalkDependencies(InferDependenciesRequest): - infer_from = SmalltalkSources - - -@rule -async def infer_smalltalk_dependencies(request: InferSmalltalkDependencies) -> InferredDependencies: - # To demo an inference rule, we simply treat each `sources` file to contain a list of - # addresses, one per line. - hydrated_sources = await Get(HydratedSources, HydrateSourcesRequest(request.sources_field)) - digest_contents = await Get(DigestContents, Digest, hydrated_sources.snapshot.digest) - all_lines = itertools.chain.from_iterable( - file_content.content.decode().splitlines() for file_content in digest_contents - ) - return InferredDependencies(Address.parse(line) for line in all_lines) - - -class TestDependencies(TestBase): - @classmethod - def rules(cls): - return ( - *super().rules(), - RootRule(DependenciesRequest), - inject_smalltalk_deps, - inject_custom_smalltalk_deps, - infer_smalltalk_dependencies, - UnionRule(InjectDependenciesRequest, InjectSmalltalkDependencies), - UnionRule(InjectDependenciesRequest, InjectCustomSmalltalkDependencies), - UnionRule(InferDependenciesRequest, InferSmalltalkDependencies), - ) - - @classmethod - def target_types(cls): - return [SmalltalkLibrary] - - def test_normal_resolution(self) -> None: - self.add_to_build_file("src/smalltalk", "smalltalk()") - addr = Address.parse("src/smalltalk") - deps_field = Dependencies(["//:dep1", "//:dep2", ":sibling"], address=addr) - assert self.request_single_product( - Addresses, Params(DependenciesRequest(deps_field), create_options_bootstrapper()) - ) == Addresses( - [ - Address.parse("//:dep1"), - Address.parse("//:dep2"), - Address.parse("src/smalltalk:sibling"), - ] - ) - - # Also test that we handle no dependencies. - empty_deps_field = Dependencies(None, address=addr) - assert self.request_single_product( - Addresses, Params(DependenciesRequest(empty_deps_field), create_options_bootstrapper()) - ) == Addresses([]) - - def test_dependency_injection(self) -> None: - self.add_to_build_file("", "smalltalk(name='target')") - - def assert_injected(deps_cls: Type[Dependencies], *, injected: List[str]) -> None: - deps_field = deps_cls(["//:provided"], address=Address.parse("//:target")) - result = self.request_single_product( - Addresses, Params(DependenciesRequest(deps_field), create_options_bootstrapper()) - ) - assert result == Addresses( - sorted(Address.parse(addr) for addr in (*injected, "//:provided")) - ) + def assert_invalid_type(raw_value: Any) -> None: + with pytest.raises(InvalidFieldTypeException): + Sources(raw_value, address=addr) + with pytest.raises(InvalidFieldTypeException): + Dependencies(raw_value, address=addr) - assert_injected(Dependencies, injected=[]) - assert_injected(SmalltalkDependencies, injected=["//:injected"]) - assert_injected(CustomSmalltalkDependencies, injected=["//:custom_injected", "//:injected"]) - - def test_dependency_inference(self) -> None: - self.add_to_build_file( - "", - dedent( - """\ - smalltalk(name='inferred1') - smalltalk(name='inferred2') - smalltalk(name='inferred3') - smalltalk(name='provided') - """ - ), - ) - self.create_file("demo/f1.st", "//:inferred1\n//:inferred2\n") - self.create_file("demo/f2.st", "//:inferred3\n") - self.add_to_build_file("demo", "smalltalk(sources=['*.st'], dependencies=['//:provided'])") - - deps_field = Dependencies(["//:provided"], address=Address.parse("demo")) - result = self.request_single_product( - Addresses, - Params( - DependenciesRequest(deps_field), - create_options_bootstrapper(args=["--dependency-inference"]), - ), - ) - assert result == Addresses( - sorted( - Address.parse(addr) - for addr in ["//:inferred1", "//:inferred2", "//:inferred3", "//:provided"] - ) - ) + for v in [0, object(), "f1.txt"]: # type: ignore[assignment] + assert_invalid_type(v) diff --git a/src/python/pants/init/engine_initializer.py b/src/python/pants/init/engine_initializer.py index 7d6539bcdcc..84b3f008113 100644 --- a/src/python/pants/init/engine_initializer.py +++ b/src/python/pants/init/engine_initializer.py @@ -10,7 +10,7 @@ from pants.base.exiter import PANTS_SUCCEEDED_EXIT_CODE from pants.base.specs import Specs from pants.build_graph.build_configuration import BuildConfiguration -from pants.engine import interactive_process, process, target +from pants.engine import interactive_process, process from pants.engine.console import Console from pants.engine.fs import Workspace, create_fs_rules from pants.engine.goal import Goal @@ -297,7 +297,6 @@ def build_root_singleton() -> BuildRoot: *graph.rules(), *options_parsing.rules(), *process.rules(), - *target.rules(), *create_fs_rules(), *create_platform_rules(), *create_graph_rules(address_mapper),