Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Literal inside Union support #237 #267

Merged
merged 15 commits into from
Apr 3, 2024
1 change: 0 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ from adaptix import Retort
class Book:
title: str
price: int
author: str = "Unknown author"


data = {
Expand Down
1 change: 1 addition & 0 deletions docs/changelog/fragments/237.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add support for dumping ``Literal`` inside ``Union``.
4 changes: 1 addition & 3 deletions docs/examples/loading-and-dumping/tutorial/tldr.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
# mypy: disable-error-code="arg-type"
from dataclasses import dataclass

from adaptix import Retort
Expand All @@ -8,7 +7,6 @@
class Book:
title: str
price: int
author: str = "Unknown author"


data = {
Expand All @@ -21,4 +19,4 @@ class Book:

book = retort.load(data, Book)
assert book == Book(title="Fahrenheit 451", price=100)
assert retort.dump(book) == {**data, "author": "Unknown author"}
assert retort.dump(book) == data
4 changes: 2 additions & 2 deletions docs/loading-and-dumping/specific-types-behavior.rst
Original file line number Diff line number Diff line change
Expand Up @@ -202,8 +202,8 @@ For objects of types that are not listed in the union,
but which are a subclass of some union case, the base class dumper is used.
If there are several parents, it will be the selected class that appears first in ``.mro()`` list.

Also, builtin dumper can not work
with union containing non-class type hints like ``Union[Literal['foo', 'bar'], int]``.
Also, builtin dumper can work only with class type hints and ``Literal``.
For example, type hints like ``LiteralString | int`` can not be dumped.

Iterable subclasses
'''''''''''''''''''''
Expand Down
1 change: 0 additions & 1 deletion docs/loading-and-dumping/tutorial.rst
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ It can create models from mapping (loading) and create mappings from the model (


.. literalinclude:: /examples/loading-and-dumping/tutorial/tldr.py
:lines: 2-

All typing information is retrieved from your annotations, so is not required from you to provide any additional schema
or even change your dataclass decorators or class bases.
Expand Down
1 change: 0 additions & 1 deletion docs/overview.rst
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ Example
==================

.. literalinclude:: /examples/loading-and-dumping/tutorial/tldr.py
:lines: 2-
:caption: Model loading and dumping
:name: loading-and-dumping-example

Expand Down
54 changes: 48 additions & 6 deletions src/adaptix/_internal/morphing/generic_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from enum import Enum
from os import PathLike
from pathlib import Path
from typing import Any, Collection, Dict, Iterable, Literal, Sequence, Set, Type, Union
from typing import Any, Collection, Dict, Iterable, Literal, Optional, Sequence, Set, Type, Union

from ..common import Dumper, Loader
from ..compat import CompatExceptionGroup
Expand Down Expand Up @@ -382,10 +382,15 @@ def _provide_dumper(self, mediator: Mediator, request: DumperRequest) -> Dumper:
return as_is_stub
return self._get_single_optional_dumper(not_none_dumper)

non_class_origins = [case.source for case in norm.args if not self._is_class_origin(case.origin)]
if non_class_origins:
forbidden_origins = [
case.source
for case in norm.args
if not self._is_class_origin(case.origin) and case.origin != Literal
]

if forbidden_origins:
raise CannotProvide(
f"All cases of union must be class, but found {non_class_origins}",
f"All cases of union must be class or Literal, but found {forbidden_origins}",
is_terminal=True,
is_demonstrative=True,
)
Expand All @@ -410,14 +415,51 @@ def _provide_dumper(self, mediator: Mediator, request: DumperRequest) -> Dumper:
dumper_type_dispatcher = ClassDispatcher(
{type(None) if case.origin is None else case.origin: dumper for case, dumper in zip(norm.args, dumpers)},
)
return self._get_dumper(dumper_type_dispatcher)

def _get_dumper(self, dumper_type_dispatcher: ClassDispatcher[Any, Dumper]) -> Dumper:
literal_dumper = self._get_dumper_for_literal(norm, dumpers, dumper_type_dispatcher)

if literal_dumper:
return literal_dumper

return self._produce_dumper(dumper_type_dispatcher)

def _produce_dumper(self, dumper_type_dispatcher: ClassDispatcher[Any, Dumper]) -> Dumper:
def union_dumper(data):
return dumper_type_dispatcher.dispatch(type(data))(data)

return union_dumper

def _produce_dumper_for_literal(
self,
dumper_type_dispatcher: ClassDispatcher[Any, Dumper],
literal_dumper: Dumper,
literal_cases: Sequence[Any],
) -> Dumper:
def union_dumper_with_literal(data):
if data in literal_cases:
return literal_dumper(data)
return dumper_type_dispatcher.dispatch(type(data))(data)

return union_dumper_with_literal

def _get_dumper_for_literal(
self,
norm: BaseNormType,
dumpers: Iterable[Any],
dumper_type_dispatcher: ClassDispatcher[Any, Dumper],
) -> Optional[Dumper]:
try:
literal_type, literal_dumper = next(
(union_case, dumper) for union_case, dumper
in zip(norm.args, dumpers)
if union_case.origin is Literal
)
except StopIteration:
return None

literal_cases = [strip_annotated(arg) for arg in literal_type.args]
return self._produce_dumper_for_literal(dumper_type_dispatcher, literal_dumper, literal_cases)

def _get_single_optional_dumper(self, dumper: Dumper) -> Dumper:
def optional_dumper(data):
if data is None:
Expand Down
38 changes: 37 additions & 1 deletion tests/unit/morphing/generic_provider/test_union_provider.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from dataclasses import dataclass
from decimal import Decimal
from typing import Callable, List, Literal, Optional, Union

import pytest
Expand Down Expand Up @@ -202,7 +203,7 @@ def test_bad_optional_dumping(retort, debug_trail):
),
with_notes(
CannotProvide(
message=f"All cases of union must be class, but found {[Callable[[int], str]]}",
message=f"All cases of union must be class or Literal, but found {[Callable[[int], str]]}",
is_demonstrative=True,
is_terminal=True,
),
Expand Down Expand Up @@ -262,3 +263,38 @@ def test_literal(strict_coercion, debug_trail):
assert dumper_("a") == "a"
assert dumper_(None) is None
assert dumper_("b") == "b"


@pytest.mark.parametrize(
["other_type", "value", "expected", "wrong_value"],
[
(
Decimal, Decimal(200.5), "200.5", [1, 2, 3],
),
(
Union[str, Decimal], "some string", "some string", [1, 2, 3],
),
],
)
def test_dump_literal_in_union(
strict_coercion,
debug_trail,
other_type,
value,
expected,
wrong_value,
):
retort = Retort()

dumper_ = retort.replace(
debug_trail=debug_trail,
).get_dumper(
Union[Literal[200, 300], other_type],
)

assert dumper_(200) == 200
assert dumper_(300) == 300
assert dumper_(value) == expected

with pytest.raises(KeyError):
dumper_(wrong_value)
Loading