diff --git a/src/python/pants/engine/target.py b/src/python/pants/engine/target.py index 8e55aa9fd75..1d8aeecc49c 100644 --- a/src/python/pants/engine/target.py +++ b/src/python/pants/engine/target.py @@ -467,9 +467,9 @@ def field_types(self) -> Tuple[Type[Field], ...]: @final @memoized_classproperty def _plugin_field_cls(cls) -> type: - # NB: We ensure that each Target subtype has its own `PluginField` class so that - # registering a plugin field doesn't leak across target types. - + # Use the `PluginField` of the first `Target`-subclass ancestor as a base to ours, so that + # we inherit the registered fields. E.g. If I inherit from `PythonSourceTarget`, I want all + # the registered fields on `PythonSourceTarget` to also be registered for me. baseclass = ( object if cast("Type[Target]", cls) is Target diff --git a/src/python/pants/engine/unions.py b/src/python/pants/engine/unions.py index 13dd10c75f2..8918438c7bb 100644 --- a/src/python/pants/engine/unions.py +++ b/src/python/pants/engine/unions.py @@ -72,14 +72,16 @@ def from_rules(cls, rules: Iterable[UnionRule]) -> UnionMembership: mapping: DefaultDict[type, OrderedSet[type]] = defaultdict(OrderedSet) for rule in rules: mapping[rule.union_base].add(rule.union_member) - # Subclassed union bases should inherit the superclass's union members. + + # Base union classes inherit the members of any subclasses that are also unions bases = list(mapping.keys()) while len(bases) > 0: union_base = bases.pop() for sub_union in union_base.__subclasses__(): - if sub_union not in mapping: - bases.append(sub_union) - mapping[sub_union].update(mapping[union_base]) + if is_union(sub_union): + if sub_union not in mapping: + bases.append(sub_union) + mapping[sub_union].update(mapping[union_base]) return cls(mapping) def __init__(self, union_rules: Mapping[type, Iterable[type]]) -> None: diff --git a/src/python/pants/engine/unions_test.py b/src/python/pants/engine/unions_test.py index f302f180e14..a36651f6e0f 100644 --- a/src/python/pants/engine/unions_test.py +++ b/src/python/pants/engine/unions_test.py @@ -5,17 +5,44 @@ from pants.util.ordered_set import FrozenOrderedSet -def test_union_membership_from_rules() -> None: +def test_simple() -> None: @union - class Base: + class Fruit: pass - class A: + class Banana(Fruit): pass - class B: + class Apple(Fruit): pass - assert UnionMembership.from_rules([UnionRule(Base, A), UnionRule(Base, B)]) == UnionMembership( - {Base: FrozenOrderedSet([A, B])} + @union + class CitrusFruit(Fruit): + pass + + class Orange(CitrusFruit): + pass + + @union + class Vegetable: + pass + + class Potato: # Doesn't _have_ to inherit from the union + pass + + union_membership = UnionMembership.from_rules( + [ + UnionRule(Fruit, Banana), + UnionRule(Fruit, Apple), + UnionRule(CitrusFruit, Orange), + UnionRule(Vegetable, Potato), + ] + ) + + assert union_membership == UnionMembership( + { + Fruit: FrozenOrderedSet([Banana, Apple]), + CitrusFruit: FrozenOrderedSet([Orange, Banana, Apple]), + Vegetable: FrozenOrderedSet([Potato]), + } )