Skip to content

Commit

Permalink
Add Target.class_has_field (#9329)
Browse files Browse the repository at this point in the history
### Problem

Currently, we have the instance method `Target.has_field()` (and `Target.has_fields()`), e.g. `tgt.has_field(PythonSources)`.

In some cases, though, we need a _classmethod_ rather than an _instance method_ to be able to do this same type of check. For example, when generating an error message when running `./v2 binary` on an invalid target type, we want to be able to calculate every single target type that _does_ work with the goal. To do that, we need to be able to iterate over every `Type[Target]` registered and call `PythonLibrary.has_fields([EntryPoint, PythonBinarySources])` so that we can decide if `PythonLibrary` is a valid target type or not.

The tricky part is that we must support plugin fields added to pre-existing target types, which are achieved by plugin authors registering `UnionRule(PythonLibrary.PluginField, MyCustomField)`.

### Solution

Add `Target.class_has_field()` and `Target.class_has_fields()`.

Factor out `Target.has_field()` to deduplicate between the classmethod and instance method.
  • Loading branch information
Eric-Arellano committed Mar 18, 2020
1 parent 5ccb04d commit 250ed9d
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 18 deletions.
72 changes: 54 additions & 18 deletions src/python/pants/engine/target.py
Expand Up @@ -245,17 +245,13 @@ def __init__(
unhydrated_values: Dict[str, Any],
*,
address: Address,
# NB: `union_membership` is only optional to facilitate tests. In production, we should
# always provide this parameter. This should be safe to do because production code should
# rarely directly instantiate Targets and should instead use the engine to request them.
union_membership: Optional[UnionMembership] = None,
) -> None:
self.address = address
self.plugin_fields = cast(
Tuple[Type[Field], ...],
(
()
if union_membership is None
else tuple(union_membership.union_rules.get(self.PluginField, ()))
),
)
self.plugin_fields = self._find_plugin_fields(union_membership or UnionMembership({}))

self.field_values = {}
aliases_to_field_types = {field_type.alias: field_type for field_type in self.field_types}
Expand Down Expand Up @@ -304,7 +300,17 @@ def __str__(self) -> str:
return f"{self.alias}({address}{fields})"

@final
def _find_registered_field_subclass(self, requested_field: Type[_F]) -> Optional[Type[_F]]:
@classmethod
def _find_plugin_fields(cls, union_membership: UnionMembership) -> Tuple[Type[Field], ...]:
return cast(
Tuple[Type[Field], ...], tuple(union_membership.union_rules.get(cls.PluginField, ()))
)

@final
@classmethod
def _find_registered_field_subclass(
cls, requested_field: Type[_F], *, registered_fields: Iterable[Type[Field]]
) -> Optional[Type[_F]]:
"""Check if the Target has registered a subclass of the requested Field.
This is necessary to allow targets to override the functionality of common fields like
Expand All @@ -315,7 +321,7 @@ def _find_registered_field_subclass(self, requested_field: Type[_F]) -> Optional
subclass = next(
(
registered_field
for registered_field in self.field_types
for registered_field in registered_fields
if issubclass(registered_field, requested_field)
),
None,
Expand All @@ -342,7 +348,9 @@ def get(self, field: Type[_F]) -> _F:
result = self.field_values.get(field, None)
if result is not None:
return cast(_F, result)
field_subclass = self._find_registered_field_subclass(field)
field_subclass = self._find_registered_field_subclass(
field, registered_fields=self.field_types
)
if field_subclass is not None:
return cast(_F, self.field_values[field_subclass])
raise KeyError(
Expand All @@ -351,6 +359,22 @@ def get(self, field: Type[_F]) -> _F:
"filter out any irrelevant Targets."
)

@final
@classmethod
def _has_fields(
cls, fields: Iterable[Type[Field]], *, registered_fields: Iterable[Type[Field]]
) -> bool:
unrecognized_fields = [field for field in fields if field not in registered_fields]
if not unrecognized_fields:
return True
for unrecognized_field in unrecognized_fields:
maybe_subclass = cls._find_registered_field_subclass(
unrecognized_field, registered_fields=registered_fields
)
if maybe_subclass is None:
return False
return True

@final
def has_field(self, field: Type[Field]) -> bool:
"""Check that this target has registered the requested field.
Expand All @@ -369,13 +393,25 @@ def has_fields(self, fields: Iterable[Type[Field]]) -> bool:
custom subclass `PythonSources`, both `python_tgt.has_fields([PythonSources])` and
`python_tgt.has_fields([Sources])` will return True.
"""
unrecognized_fields = [field for field in fields if field not in self.field_types]
if not unrecognized_fields:
return True
for unrecognized_field in unrecognized_fields:
if self._find_registered_field_subclass(unrecognized_field) is None:
return False
return True
return self._has_fields(fields, registered_fields=self.field_types)

@final
@classmethod
def class_has_field(cls, field: Type[Field], *, union_membership: UnionMembership) -> bool:
"""Behaves like `Target.has_field()`, but works as a classmethod rather than an instance
method."""
return cls.class_has_fields([field], union_membership=union_membership)

@final
@classmethod
def class_has_fields(
cls, fields: Iterable[Type[Field]], *, union_membership: UnionMembership
) -> bool:
"""Behaves like `Target.has_fields()`, but works as a classmethod rather than an instance
method."""
return cls._has_fields(
fields, registered_fields=(*cls.core_fields, *cls._find_plugin_fields(union_membership))
)


# TODO: add light-weight runtime type checking to these helper fields, such as ensuring that
Expand Down
31 changes: 31 additions & 0 deletions src/python/pants/engine/target_test.py
Expand Up @@ -154,16 +154,43 @@ class UnrelatedField(BoolField):
alias: ClassVar = "unrelated"
default: ClassVar = False

empty_union_membership = UnionMembership({})

tgt = HaskellTarget({}, address=Address.parse(":lib"))
assert tgt.has_fields([]) is True
assert HaskellTarget.class_has_fields([], union_membership=empty_union_membership) is True

assert tgt.has_fields([HaskellGhcExtensions]) is True
assert tgt.has_field(HaskellGhcExtensions) is True
assert (
HaskellTarget.class_has_fields(
[HaskellGhcExtensions], union_membership=empty_union_membership
)
is True
)
assert (
HaskellTarget.class_has_field(HaskellGhcExtensions, union_membership=empty_union_membership)
is True
)

assert tgt.has_fields([UnrelatedField]) is False
assert tgt.has_field(UnrelatedField) is False
assert (
HaskellTarget.class_has_fields([UnrelatedField], union_membership=empty_union_membership)
is False
)
assert (
HaskellTarget.class_has_field(UnrelatedField, union_membership=empty_union_membership)
is False
)

assert tgt.has_fields([HaskellGhcExtensions, UnrelatedField]) is False
assert (
HaskellTarget.class_has_fields(
[HaskellGhcExtensions, UnrelatedField], union_membership=empty_union_membership
)
is False
)


def test_primitive_field_hydration_is_eager() -> None:
Expand All @@ -186,9 +213,13 @@ class CustomField(BoolField):
tgt = HaskellTarget(
tgt_values, address=Address.parse(":lib"), union_membership=union_membership
)

assert tgt.field_types == (HaskellGhcExtensions, HaskellSources, CustomField)
assert tgt.core_fields == (HaskellGhcExtensions, HaskellSources)
assert tgt.plugin_fields == (CustomField,)
assert tgt.has_field(CustomField) is True
assert HaskellTarget.class_has_field(CustomField, union_membership=union_membership) is True

assert tgt.get(CustomField).value is True

default_tgt = HaskellTarget(
Expand Down

0 comments on commit 250ed9d

Please sign in to comment.