diff --git a/src/attrs/__init__.py b/src/attrs/__init__.py index 0c2481561..00acbdb6a 100644 --- a/src/attrs/__init__.py +++ b/src/attrs/__init__.py @@ -43,6 +43,7 @@ "AttrsInstance", "cmp_using", "converters", + "custom_fields", "define", "evolve", "exceptions", diff --git a/src/attrs/custom_fields.py b/src/attrs/custom_fields.py new file mode 100644 index 000000000..db2ee7aa9 --- /dev/null +++ b/src/attrs/custom_fields.py @@ -0,0 +1,64 @@ +from __future__ import annotations + +import typing + +from typing_extensions import Protocol + +from attr._make import _make_attr_tuple_class +from attrs import Attribute, AttrsInstance, fields +from attrs import resolve_types as _resolve_types + + +__all__ = ["custom_fields"] + +T = typing.TypeVar("T") + + +class AttributeModel(Protocol[T]): + """Custom attributes must conform to this.""" + + @classmethod + def _from_attrs_attribute( + cls: type[AttributeModel], + cl: type[AttrsInstance], + attribute: Attribute[T], + ) -> AttributeModel[T]: + """Create a custom attribute model from an `attrs.Attribute`.""" + ... + + +def custom_fields( + cls: type[AttrsInstance], + attribute_model: type[AttributeModel], + resolve_types: bool = False, +): + """ + Return the attrs fields tuple for cls with the provided attribute model. + + :param type cls: Class to introspect. + :param attribute_model: The attribute model to use. + :param resolve_types: Whether to resolve the class types first. + + :raise TypeError: If *cls* is not a class. + :raise attrs.exceptions.NotAnAttrsClassError: If *cls* is not an *attrs* + class. + + :rtype: tuple (with name accessors) of `attribute_model`. + + .. versionadded:: 23.2.0 + """ + attrs = getattr(cls, f"__attrs_{id(attribute_model)}__", None) + + if attrs is None: + if resolve_types: + _resolve_types(cls) + base_attrs = fields(cls) + AttrsClass = _make_attr_tuple_class( + cls.__name__, [a.name for a in base_attrs] + ) + attrs = AttrsClass( + attribute_model._from_attrs_attribute(cls, a) for a in base_attrs + ) + setattr(cls, f"__attrs_{id(attribute_model)}__", attrs) + + return attrs diff --git a/tests/test_custom_fields.py b/tests/test_custom_fields.py new file mode 100644 index 000000000..f178504cf --- /dev/null +++ b/tests/test_custom_fields.py @@ -0,0 +1,76 @@ +"""Tests for the custom attributes functionality.""" +from __future__ import annotations + +from functools import partial +from typing import Generic, TypeVar + +from attrs import Attribute, AttrsInstance, define +from attrs.custom_fields import custom_fields + + +T = TypeVar("T") + + +@define +class CustomAttribute(Generic[T]): + """A custom attribute, for tests.""" + + cl: type[AttrsInstance] + name: str + attribute_type: T + + @classmethod + def _from_attrs_attribute( + cls, attrs_cls: type[AttrsInstance], attribute: Attribute[T] + ): + return cls(attrs_cls, attribute.name, attribute.type) + + +cust_fields = partial(custom_fields, attribute_model=CustomAttribute) +cust_resolved_fields = partial( + custom_fields, attribute_model=CustomAttribute, resolve_types=True +) + + +def test_simple_custom_fields(): + """Simple custom attribute overriding works.""" + + @define + class Test: + a: int + b: float + + for _ in range(2): + # Do it twice to test caching. + f = cust_fields(Test) + + assert isinstance(f.a, CustomAttribute) + assert isinstance(f.b, CustomAttribute) + + assert not hasattr(f, "c") + + assert f.a.name == "a" + assert f.a.cl is Test + assert f.a.attribute_type == "int" + + +def test_resolved_custom_fields(): + """Resolved custom attributes work.""" + + @define + class Test: + a: int + b: float + + for _ in range(2): + # Do it twice to test caching. + f = cust_resolved_fields(Test) + + assert isinstance(f.a, CustomAttribute) + assert isinstance(f.b, CustomAttribute) + + assert not hasattr(f, "c") + + assert f.a.name == "a" + assert f.a.cl is Test + assert f.a.attribute_type is int