From abdbfaa9f238d9d8e08df6a71ce57d64056041ee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tin=20Tvrtkovi=C4=87?= Date: Fri, 20 Oct 2023 13:27:59 +0200 Subject: [PATCH 1/4] Tweak docs --- HISTORY.md | 10 ++++---- docs/unions.md | 49 ++++++++++++++++++++++++++++++++++-- src/cattrs/disambiguators.py | 10 +++++--- 3 files changed, 58 insertions(+), 11 deletions(-) diff --git a/HISTORY.md b/HISTORY.md index 0f71a2ef..61e48929 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -8,19 +8,19 @@ - **Potentially breaking**: {py:func}`cattrs.gen.make_dict_structure_fn` and {py:func}`cattrs.gen.typeddicts.make_dict_structure_fn` will use the values for the `detailed_validation` and `forbid_extra_keys` parameters from the given converter by default now. If you're using these functions directly, the old behavior can be restored by passing in the desired values directly. ([#410](https://github.com/python-attrs/cattrs/issues/410) [#411](https://github.com/python-attrs/cattrs/pull/411)) +- **Potentially breaking**: The default union structuring strategy will also use fields annotated as `typing.Literal` to help guide structuring. + ([#391](https://github.com/python-attrs/cattrs/pull/391)) - Python 3.12 is now supported. Python 3.7 is no longer supported; use older releases there. ([#424](https://github.com/python-attrs/cattrs/pull/424)) +- Implement the `union passthrough` strategy, enabling much richer union handling for preconfigured converters. [Learn more here](https://catt.rs/en/stable/strategies.html#union-passthrough). - Introduce the `use_class_methods` strategy. Learn more [here](https://catt.rs/en/latest/strategies.html#using-class-specific-structure-and-unstructure-methods). ([#405](https://github.com/python-attrs/cattrs/pull/405)) -- Implement the `union passthrough` strategy, enabling much richer union handling for preconfigured converters. [Learn more here](https://catt.rs/en/stable/strategies.html#union-passthrough). - The `omit` parameter of {py:func}`cattrs.override` is now of type `bool | None` (from `bool`). `None` is the new default and means to apply default _cattrs_ handling to the attribute, which is to omit the attribute if it's marked as `init=False`, and keep it otherwise. - Fix {py:func}`format_exception() ` parameter working for recursive calls to {py:func}`transform_error `. ([#389](https://github.com/python-attrs/cattrs/issues/389)) - [_attrs_ aliases](https://www.attrs.org/en/stable/init.html#private-attributes-and-aliases) are now supported, although aliased fields still map to their attribute name instead of their alias by default when un/structuring. ([#322](https://github.com/python-attrs/cattrs/issues/322) [#391](https://github.com/python-attrs/cattrs/pull/391)) -- Use [PDM](https://pdm.fming.dev/latest/) instead of Poetry. -- _cattrs_ is now linted with [Ruff](https://beta.ruff.rs/docs/). - Fix TypedDicts with periods in their field names. ([#376](https://github.com/python-attrs/cattrs/issues/376) [#377](https://github.com/python-attrs/cattrs/pull/377)) - Optimize and improve unstructuring of `Optional` (unions of one type and `None`). @@ -45,10 +45,10 @@ ([#420](https://github.com/python-attrs/cattrs/pull/420)) - Add support for `datetime.date`s to the PyYAML preconfigured converter. ([#393](https://github.com/python-attrs/cattrs/issues/393)) +- Use [PDM](https://pdm.fming.dev/latest/) instead of Poetry. +- _cattrs_ is now linted with [Ruff](https://beta.ruff.rs/docs/). - Remove some unused lines in the unstructuring code. ([#416](https://github.com/python-attrs/cattrs/pull/416)) -- Disambiguate a union of attrs classes where there's a `typing.Literal` tag of some sort. - ([#391](https://github.com/python-attrs/cattrs/pull/391)) ## 23.1.2 (2023-06-02) diff --git a/docs/unions.md b/docs/unions.md index 4385bc91..a9876c68 100644 --- a/docs/unions.md +++ b/docs/unions.md @@ -2,7 +2,7 @@ This sections contains information for advanced union handling. -As mentioned in the structuring section, _cattrs_ is able to handle simple unions of _attrs_ classes automatically. +As mentioned in the structuring section, _cattrs_ is able to handle simple unions of _attrs_ classes [automatically](#default-union-strategy). More complex cases require converter customization (since there are many ways of handling unions). _cattrs_ also comes with a number of strategies to help handle unions: @@ -10,7 +10,52 @@ _cattrs_ also comes with a number of strategies to help handle unions: - [tagged unions strategy](strategies.md#tagged-unions-strategy) mentioned below - [union passthrough strategy](strategies.md#union-passthrough), which is preapplied to all the [preconfigured](preconf.md) converters -## Unstructuring unions with extra metadata +## Default Union Strategy + +For convenience, _cattrs_ includes a default union structuring strategy which is a little more opinionated. + +Given a union of several _attrs_ classes, the default union strategy will attempt to handle it in several ways. + +First, it will look for `Literal` fields. +If all members of the union contain a literal field, _cattrs_ will generate a disambiguation function based on the field. + +```python +from typing import Literal + +@define +class ClassA: + field_one: Literal["one"] + +@define +class ClassB: + field_one: Literal["two"] +``` + +In this case, a payload containing `{"field_one": "one"}` will produce an instance of `ClassA`. + +If there are no appropriate fields, the strategy will examine the classes for **unique required fields**. + +So, given a union of `ClassA` and `ClassB`: + +```python +@define +class ClassA: + field_one: str + field_with_default: str = "a default" + +@define +class ClassB: + field_two: str +``` + +the strategy will determine that if a payload contains the key `field_one` it should be handled as `ClassA`, and if it contains the key `field_two` it should be handled as `ClassB`. +The field `field_with_default` will not be considered since it has a default value, so it gets treated as optional. + +```{versionchanged} 23.2.0 +Literals can now be potentially used to disambiguate. +``` + +## Unstructuring Unions with Extra Metadata ```{note} _cattrs_ comes with the [tagged unions strategy](strategies.md#tagged-unions-strategy) for handling this exact use-case since version 23.1. diff --git a/src/cattrs/disambiguators.py b/src/cattrs/disambiguators.py index aa7671df..38c32974 100644 --- a/src/cattrs/disambiguators.py +++ b/src/cattrs/disambiguators.py @@ -2,9 +2,9 @@ from collections import OrderedDict, defaultdict from functools import reduce from operator import or_ -from typing import Any, Callable, Dict, Mapping, Optional, Type, Union +from typing import Any, Callable, Dict, Mapping, Optional, Set, Type, Union -from attr import NOTHING, fields, fields_dict +from attrs import NOTHING, fields, fields_dict from cattrs._compat import get_args, get_origin, is_literal @@ -12,7 +12,7 @@ def create_default_dis_func( *classes: Type[Any], ) -> Callable[[Mapping[Any, Any]], Optional[Type[Any]]]: - """Given attr classes, generate a disambiguation function. + """Given attrs classes, generate a disambiguation function. The function is based on unique fields or unique values.""" if len(classes) < 2: @@ -33,13 +33,15 @@ def create_default_dis_func( for cl in classes ] - discriminators = cls_candidates[0] + # literal field names common to all members + discriminators: Set[str] = cls_candidates[0] for possible_discriminators in cls_candidates: discriminators &= possible_discriminators best_result = None best_discriminator = None for discriminator in discriminators: + # maps Literal values (strings, ints...) to classes mapping = defaultdict(list) for cl in classes: From 68c215ecb0fe9a15e5ebd3e098705545ad883678 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tin=20Tvrtkovi=C4=87?= Date: Sat, 21 Oct 2023 16:00:36 +0200 Subject: [PATCH 2/4] Docs and tests --- docs/unions.md | 15 +++ src/cattrs/converters.py | 32 ++--- src/cattrs/disambiguators.py | 115 +++++++++------- ...isambigutors.py => test_disambiguators.py} | 125 +++++++++++++----- 4 files changed, 183 insertions(+), 104 deletions(-) rename tests/{test_disambigutors.py => test_disambiguators.py} (61%) diff --git a/docs/unions.md b/docs/unions.md index a9876c68..ac031f6b 100644 --- a/docs/unions.md +++ b/docs/unions.md @@ -33,6 +33,21 @@ class ClassB: In this case, a payload containing `{"field_one": "one"}` will produce an instance of `ClassA`. +````{note} +The following snippet can be used to disable the use of literal fields, restoring the previous behavior. + +```python +from functools import partial +from cattrs.disambiguators import is_supported_union + +converter.register_structure_hook_factory( + is_supported_union, + partial(converter._gen_attrs_union_structure, use_literals=False), +) +``` + +```` + If there are no appropriate fields, the strategy will examine the classes for **unique required fields**. So, given a union of `ClassA` and `ClassB`: diff --git a/src/cattrs/converters.py b/src/cattrs/converters.py index 1b42512e..e97d7cf8 100644 --- a/src/cattrs/converters.py +++ b/src/cattrs/converters.py @@ -19,9 +19,9 @@ Union, ) -from attr import Attribute -from attr import has as attrs_has -from attr import resolve_types +from attrs import Attribute +from attrs import has as attrs_has +from attrs import resolve_types from ._compat import ( FrozenSetSubscriptable, @@ -55,7 +55,7 @@ is_typeddict, is_union_type, ) -from .disambiguators import create_default_dis_func +from .disambiguators import create_default_dis_func, is_supported_union from .dispatch import MultiStrategyDispatch from .errors import ( IterableValidationError, @@ -96,16 +96,6 @@ def _subclass(typ: Type) -> Callable[[Type], bool]: return lambda cls: issubclass(cls, typ) -def is_attrs_union(typ: Type) -> bool: - return is_union_type(typ) and all(has(get_origin(e) or e) for e in typ.__args__) - - -def is_attrs_union_or_none(typ: Type) -> bool: - return is_union_type(typ) and all( - e is NoneType or has(get_origin(e) or e) for e in typ.__args__ - ) - - def is_optional(typ: Type) -> bool: return is_union_type(typ) and NoneType in typ.__args__ and len(typ.__args__) == 2 @@ -204,7 +194,7 @@ def __init__( (is_frozenset, self._structure_frozenset), (is_tuple, self._structure_tuple), (is_mapping, self._structure_dict), - (is_attrs_union_or_none, self._gen_attrs_union_structure, True), + (is_supported_union, self._gen_attrs_union_structure, True), ( lambda t: is_union_type(t) and t in self._union_struct_registry, self._structure_union, @@ -411,17 +401,19 @@ def _gen_structure_generic(self, cl: Type[T]) -> DictStructureFn[T]: ) def _gen_attrs_union_structure( - self, cl: Any + self, cl: Any, use_literals: bool = True ) -> Callable[[Any, Type[T]], Optional[Type[T]]]: """ Generate a structuring function for a union of attrs classes (and maybe None). + + :param use_literals: Whether to consider literal fields. """ - dis_fn = self._get_dis_func(cl) + dis_fn = self._get_dis_func(cl, use_literals=use_literals) has_none = NoneType in cl.__args__ if has_none: - def structure_attrs_union(obj, _): + def structure_attrs_union(obj, _) -> cl: if obj is None: return None return self.structure(obj, dis_fn(obj)) @@ -719,7 +711,7 @@ def _structure_tuple(self, obj: Any, tup: Type[T]) -> T: return res @staticmethod - def _get_dis_func(union: Any) -> Callable[[Any], Type]: + def _get_dis_func(union: Any, use_literals: bool) -> Callable[[Any], Type]: """Fetch or try creating a disambiguation function for a union.""" union_types = union.__args__ if NoneType in union_types: # type: ignore @@ -738,7 +730,7 @@ def _get_dis_func(union: Any) -> Callable[[Any], Type]: type_=union, ) - return create_default_dis_func(*union_types) + return create_default_dis_func(*union_types, use_literals=use_literals) def __deepcopy__(self, _) -> "BaseConverter": return self.copy() diff --git a/src/cattrs/disambiguators.py b/src/cattrs/disambiguators.py index 38c32974..281954e1 100644 --- a/src/cattrs/disambiguators.py +++ b/src/cattrs/disambiguators.py @@ -6,69 +6,82 @@ from attrs import NOTHING, fields, fields_dict -from cattrs._compat import get_args, get_origin, is_literal +from ._compat import get_args, get_origin, has, is_literal, is_union_type + +__all__ = ("is_supported_union", "create_default_dis_func") + +NoneType = type(None) + + +def is_supported_union(typ: Type) -> bool: + """Whether the type is a union of attrs classes.""" + return is_union_type(typ) and all( + e is NoneType or has(get_origin(e) or e) for e in typ.__args__ + ) def create_default_dis_func( - *classes: Type[Any], + *classes: Type[Any], use_literals: bool = True ) -> Callable[[Mapping[Any, Any]], Optional[Type[Any]]]: """Given attrs classes, generate a disambiguation function. - The function is based on unique fields or unique values.""" + The function is based on unique fields or unique values. + + :param use_literals: Whether to try using fields annotated as literals for + disambiguation. + """ if len(classes) < 2: raise ValueError("At least two classes required.") # first, attempt for unique values + if use_literals: + # requirements for a discriminator field: + # (... TODO: a single fallback is OK) + # - it must always be enumerated + cls_candidates = [ + {at.name for at in fields(get_origin(cl) or cl) if is_literal(at.type)} + for cl in classes + ] + + # literal field names common to all members + discriminators: Set[str] = cls_candidates[0] + for possible_discriminators in cls_candidates: + discriminators &= possible_discriminators + + best_result = None + best_discriminator = None + for discriminator in discriminators: + # maps Literal values (strings, ints...) to classes + mapping = defaultdict(list) + + for cl in classes: + for key in get_args( + fields_dict(get_origin(cl) or cl)[discriminator].type + ): + mapping[key].append(cl) + + if best_result is None or max(len(v) for v in mapping.values()) <= max( + len(v) for v in best_result.values() + ): + best_result = mapping + best_discriminator = discriminator + + if ( + best_result + and best_discriminator + and max(len(v) for v in best_result.values()) != len(classes) + ): + final_mapping = { + k: v[0] if len(v) == 1 else Union[tuple(v)] + for k, v in best_result.items() + } - # requirements for a discriminator field: - # (... TODO: a single fallback is OK) - # - it must be *required* - # - it must always be enumerated - cls_candidates = [ - { - at.name - for at in fields(get_origin(cl) or cl) - if at.default is NOTHING and is_literal(at.type) - } - for cl in classes - ] - - # literal field names common to all members - discriminators: Set[str] = cls_candidates[0] - for possible_discriminators in cls_candidates: - discriminators &= possible_discriminators - - best_result = None - best_discriminator = None - for discriminator in discriminators: - # maps Literal values (strings, ints...) to classes - mapping = defaultdict(list) - - for cl in classes: - for key in get_args(fields_dict(get_origin(cl) or cl)[discriminator].type): - mapping[key].append(cl) + def dis_func(data: Mapping[Any, Any]) -> Optional[Type]: + if not isinstance(data, Mapping): + raise ValueError("Only input mappings are supported.") + return final_mapping[data[best_discriminator]] - if best_result is None or max(len(v) for v in mapping.values()) <= max( - len(v) for v in best_result.values() - ): - best_result = mapping - best_discriminator = discriminator - - if ( - best_result - and best_discriminator - and max(len(v) for v in best_result.values()) != len(classes) - ): - final_mapping = { - k: v[0] if len(v) == 1 else Union[tuple(v)] for k, v in best_result.items() - } - - def dis_func(data: Mapping[Any, Any]) -> Optional[Type]: - if not isinstance(data, Mapping): - raise ValueError("Only input mappings are supported.") - return final_mapping[data[best_discriminator]] - - return dis_func + return dis_func # next, attempt for unique keys diff --git a/tests/test_disambigutors.py b/tests/test_disambiguators.py similarity index 61% rename from tests/test_disambigutors.py rename to tests/test_disambiguators.py index 4fd37f13..4802211a 100644 --- a/tests/test_disambigutors.py +++ b/tests/test_disambiguators.py @@ -1,12 +1,16 @@ """Tests for auto-disambiguators.""" from typing import Any, Literal, Union -import attr import pytest -from attrs import NOTHING, define +from attrs import NOTHING, asdict, define, field, fields from hypothesis import HealthCheck, assume, given, settings -from cattrs.disambiguators import create_default_dis_func, create_uniq_field_dis_func +from cattrs import Converter +from cattrs.disambiguators import ( + create_default_dis_func, + create_uniq_field_dis_func, + is_supported_union, +) from .untyped import simple_classes @@ -14,7 +18,7 @@ def test_edge_errors(): """Edge input cases cause errors.""" - @attr.s + @define class A: pass @@ -25,7 +29,7 @@ class A: with pytest.raises(ValueError): create_default_dis_func(A) - @attr.s + @define class B: pass @@ -36,13 +40,13 @@ class B: with pytest.raises(ValueError): create_default_dis_func(A, B) - @attr.s + @define class C: - a = attr.ib() + a = field() - @attr.s + @define class D: - a = attr.ib() + a = field() with pytest.raises(ValueError): # No unique fields on either class. @@ -52,23 +56,23 @@ class D: # No discriminator candidates create_default_dis_func(C, D) - @attr.s + @define class E: pass - @attr.s + @define class F: - b = attr.ib(default=Any) + b = None with pytest.raises(ValueError): # no usable non-default attributes create_uniq_field_dis_func(E, F) - @define() + @define class G: x: Literal[1] - @define() + @define class H: x: Literal[1] @@ -82,18 +86,18 @@ def test_fallback(cl_and_vals): """The fallback case works.""" cl, vals, kwargs = cl_and_vals - assume(attr.fields(cl)) # At least one field. + assume(fields(cl)) # At least one field. - @attr.s + @define class A: pass fn = create_uniq_field_dis_func(A, cl) assert fn({}) is A - assert fn(attr.asdict(cl(*vals, **kwargs))) is cl + assert fn(asdict(cl(*vals, **kwargs))) is cl - attr_names = {a.name for a in attr.fields(cl)} + attr_names = {a.name for a in fields(cl)} if "xyz" not in attr_names: assert fn({"xyz": 1}) is A # Uses the fallback. @@ -106,31 +110,31 @@ def test_disambiguation(cl_and_vals_a, cl_and_vals_b): cl_a, vals_a, kwargs_a = cl_and_vals_a cl_b, vals_b, kwargs_b = cl_and_vals_b - req_a = {a.name for a in attr.fields(cl_a)} - req_b = {a.name for a in attr.fields(cl_b)} + req_a = {a.name for a in fields(cl_a)} + req_b = {a.name for a in fields(cl_b)} assume(len(req_a)) assume(len(req_b)) assume((req_a - req_b) or (req_b - req_a)) for attr_name in req_a - req_b: - assume(getattr(attr.fields(cl_a), attr_name).default is NOTHING) + assume(getattr(fields(cl_a), attr_name).default is NOTHING) for attr_name in req_b - req_a: - assume(getattr(attr.fields(cl_b), attr_name).default is NOTHING) + assume(getattr(fields(cl_b), attr_name).default is NOTHING) fn = create_uniq_field_dis_func(cl_a, cl_b) - assert fn(attr.asdict(cl_a(*vals_a, **kwargs_a))) is cl_a + assert fn(asdict(cl_a(*vals_a, **kwargs_a))) is cl_a # not too sure of properties of `create_default_dis_func` def test_disambiguate_from_discriminated_enum(): # can it find any discriminator? - @define() + @define class A: a: Literal[0] - @define() + @define class B: a: Literal[1] @@ -139,12 +143,12 @@ class B: assert fn({"a": 1}) is B # can it find the better discriminator? - @define() + @define class C: a: Literal[0] b: Literal[1] - @define() + @define class D: a: Literal[0] b: Literal[0] @@ -155,16 +159,16 @@ class D: # can it handle multiple tiers of discriminators? # (example inspired by Discord's gateway's discriminated union) - @define() + @define class E: op: Literal[1] - @define() + @define class F: op: Literal[0] t: Literal["MESSAGE_CREATE"] - @define() + @define class G: op: Literal[0] t: Literal["MESSAGE_UPDATE"] @@ -174,18 +178,73 @@ class G: assert fn({"op": 0, "t": "MESSAGE_CREATE"}) is Union[F, G] # can it handle multiple literals? - @define() + @define class H: a: Literal[1] - @define() + @define class J: a: Literal[0, 1] - @define() + @define class K: a: Literal[0] fn = create_default_dis_func(H, J, K) assert fn({"a": 1}) is Union[H, J] assert fn({"a": 0}) is Union[J, K] + + +def test_default_no_literals(): + """The default disambiguator can skip literals.""" + + @define + class A: + a: Literal["a"] = "a" + + @define + class B: + a: Literal["b"] = "b" + + default = create_default_dis_func(A, B) # Should work. + assert default({"a": "a"}) is A + + with pytest.raises(ValueError): + create_default_dis_func(A, B, use_literals=False) + + @define + class C: + b: int + a: Literal["a"] = "a" + + @define + class D: + a: Literal["b"] = "b" + + default = create_default_dis_func(C, D) # Should work. + assert default({"a": "a"}) is C + + no_lits = create_default_dis_func(C, D, use_literals=False) + assert no_lits({"a": "a", "b": 1}) is C + assert no_lits({"a": "a"}) is D + + +def test_converter_no_literals(converter: Converter): + """A converter can be configured to skip literals.""" + from functools import partial + + converter.register_structure_hook_factory( + is_supported_union, + partial(converter._gen_attrs_union_structure, use_literals=False), + ) + + @define + class C: + b: int + a: Literal["a"] = "a" + + @define + class D: + a: Literal["b"] = "b" + + assert converter.structure({}, Union[C, D]) == D() From fe43f684e624fc438862d08eda9a5e83fe934544 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tin=20Tvrtkovi=C4=87?= Date: Sat, 21 Oct 2023 17:48:04 +0200 Subject: [PATCH 3/4] Restore default --- src/cattrs/converters.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/cattrs/converters.py b/src/cattrs/converters.py index e97d7cf8..22f88271 100644 --- a/src/cattrs/converters.py +++ b/src/cattrs/converters.py @@ -711,7 +711,7 @@ def _structure_tuple(self, obj: Any, tup: Type[T]) -> T: return res @staticmethod - def _get_dis_func(union: Any, use_literals: bool) -> Callable[[Any], Type]: + def _get_dis_func(union: Any, use_literals: bool = True) -> Callable[[Any], Type]: """Fetch or try creating a disambiguation function for a union.""" union_types = union.__args__ if NoneType in union_types: # type: ignore From 7ae4927ce8806dd05e0fa18ef98599c0a2c9fc44 Mon Sep 17 00:00:00 2001 From: Tin Tvrtkovic Date: Sat, 21 Oct 2023 23:59:24 +0200 Subject: [PATCH 4/4] Remove unused import --- tests/test_disambiguators.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_disambiguators.py b/tests/test_disambiguators.py index 4802211a..5ec9ad7c 100644 --- a/tests/test_disambiguators.py +++ b/tests/test_disambiguators.py @@ -1,5 +1,5 @@ """Tests for auto-disambiguators.""" -from typing import Any, Literal, Union +from typing import Literal, Union import pytest from attrs import NOTHING, asdict, define, field, fields