From 79e431a4dac984f1b9a096d44a389c08e569ad73 Mon Sep 17 00:00:00 2001 From: Sean Stewart Date: Wed, 16 Oct 2024 09:26:57 -0400 Subject: [PATCH] fix: correct handling optional types `None` is always the last arg in an optional type, which prevents proper behavior for null values in the default case. This change adds short-circuiting to our routines to handle nullable types correctly. --- src/typelib/marshals/routines.py | 12 ++++++-- src/typelib/unmarshals/routines.py | 7 ++++- tests/unit/marshals/test_api.py | 5 ++++ tests/unit/marshals/test_routines.py | 40 +++++++++++++++++++------- tests/unit/unmarshals/test_api.py | 5 ++++ tests/unit/unmarshals/test_routines.py | 18 ++++++++++++ 6 files changed, 72 insertions(+), 15 deletions(-) diff --git a/src/typelib/marshals/routines.py b/src/typelib/marshals/routines.py index b0012d0..d49bea0 100644 --- a/src/typelib/marshals/routines.py +++ b/src/typelib/marshals/routines.py @@ -262,7 +262,7 @@ class UnionMarshaller(AbstractMarshaller[UnionT], tp.Generic[UnionT]): - [`UnionUnmarshaller`][typelib.unmarshals.routines.UnionUnmarshaller] """ - __slots__ = ("stack", "ordered_routines") + __slots__ = ("stack", "ordered_routines", "nullable") def __init__(self, t: type[UnionT], context: ContextT, *, var: str | None = None): """Constructor. @@ -274,10 +274,11 @@ def __init__(self, t: type[UnionT], context: ContextT, *, var: str | None = None """ super().__init__(t, context, var=var) self.stack = inspection.args(t) + self.nullable = inspection.isoptionaltype(t) self.ordered_routines = [self.context[typ] for typ in self.stack] def __call__(self, val: UnionT) -> serdes.MarshalledValueT: - """Unmarshal a value into the bound `UnionT`. + """Marshal a value into the bound `UnionT`. Args: val: The input value to unmarshal. @@ -285,8 +286,13 @@ def __call__(self, val: UnionT) -> serdes.MarshalledValueT: Raises: ValueError: If `val` cannot be marshalled via any member type. """ + if self.nullable and val is None: + return val + for routine in self.ordered_routines: - with contextlib.suppress(ValueError, TypeError, SyntaxError): + with contextlib.suppress( + ValueError, TypeError, SyntaxError, AttributeError + ): unmarshalled = routine(val) return unmarshalled diff --git a/src/typelib/unmarshals/routines.py b/src/typelib/unmarshals/routines.py index f257a4e..b4b5b24 100644 --- a/src/typelib/unmarshals/routines.py +++ b/src/typelib/unmarshals/routines.py @@ -678,6 +678,9 @@ def __init__(self, t: type[UnionT], context: ContextT, *, var: str | None = None """ super().__init__(t, context, var=var) self.stack = inspection.args(t) + if inspection.isoptionaltype(t): + self.stack = (self.stack[-1], *self.stack[:-1]) + self.ordered_routines = [self.context[typ] for typ in self.stack] def __call__(self, val: tp.Any) -> UnionT: @@ -690,7 +693,9 @@ def __call__(self, val: tp.Any) -> UnionT: ValueError: If `val` cannot be unmarshalled into any member type. """ for routine in self.ordered_routines: - with contextlib.suppress(ValueError, TypeError, SyntaxError): + with contextlib.suppress( + ValueError, TypeError, SyntaxError, AttributeError + ): unmarshalled = routine(val) return unmarshalled diff --git a/tests/unit/marshals/test_api.py b/tests/unit/marshals/test_api.py index 6121d44..7209f3f 100644 --- a/tests/unit/marshals/test_api.py +++ b/tests/unit/marshals/test_api.py @@ -42,6 +42,11 @@ given_input=2, expected_output=2, ), + optional_none=dict( + given_type=typing.Optional[typing.Union[int, str]], + given_input=None, + expected_output=None, + ), datetime=dict( given_type=datetime.datetime, given_input=datetime.datetime.fromtimestamp(0, datetime.timezone.utc), diff --git a/tests/unit/marshals/test_routines.py b/tests/unit/marshals/test_routines.py index 51cf244..a6c34bf 100644 --- a/tests/unit/marshals/test_routines.py +++ b/tests/unit/marshals/test_routines.py @@ -126,7 +126,7 @@ def test_date_marshaller(given_input, expected_output): expected_output=datetime.datetime(1969, 12, 31).isoformat(), ), ) -def test_datetime_unmarshaller(given_input, expected_output): +def test_datetime_marshaller(given_input, expected_output): # Given given_marshaller = routines.DateTimeMarshaller(datetime.datetime, {}) # When @@ -141,7 +141,7 @@ def test_datetime_unmarshaller(given_input, expected_output): expected_output="00:00:00+00:00", ), ) -def test_time_unmarshaller(given_input, expected_output): +def test_time_marshaller(given_input, expected_output): # Given given_marshaller = routines.TimeMarshaller(datetime.time, {}) # When @@ -153,7 +153,7 @@ def test_time_unmarshaller(given_input, expected_output): @pytest.mark.suite( timedelta=dict(given_input=datetime.timedelta(seconds=1), expected_output="PT1S"), ) -def test_timedelta_unmarshaller(given_input, expected_output): +def test_timedelta_marshaller(given_input, expected_output): # Given given_marshaller = routines.TimeDeltaMarshaller(datetime.timedelta, {}) # When @@ -187,7 +187,7 @@ def test_mapping_marshaller(given_input, expected_output): expected_output=["field", "value"], ), ) -def test_iterable_unmarshaller(given_input, expected_output): +def test_iterable_marshaller(given_input, expected_output): # Given given_marshaller = routines.IterableMarshaller(typing.Iterable, {}) # When @@ -259,8 +259,26 @@ def test_literal_marshaller(given_input, given_literal, given_context, expected_ }, expected_output=1, ), + optional_date_none=dict( + given_input=None, + given_union=typing.Optional[datetime.date], + given_context={ + datetime.date: routines.DateMarshaller(datetime.date, {}), + type(None): routines.NoOpMarshaller(type(None), {}), + }, + expected_output=None, + ), + optional_date_date=dict( + given_input=datetime.date.today(), + given_union=typing.Optional[datetime.date], + given_context={ + datetime.date: routines.DateMarshaller(datetime.date, {}), + type(None): routines.NoOpMarshaller(type(None), {}), + }, + expected_output=datetime.date.today().isoformat(), + ), ) -def test_union_unmarshaller(given_input, given_union, given_context, expected_output): +def test_union_marshaller(given_input, given_union, given_context, expected_output): # Given given_marshaller = routines.UnionMarshaller(given_union, given_context) # When @@ -280,7 +298,7 @@ def test_union_unmarshaller(given_input, given_union, given_context, expected_ou expected_output={"field": 1}, ), ) -def test_subscripted_mapping_unmarshaller( +def test_subscripted_mapping_marshaller( given_input, given_mapping, given_context, expected_output ): # Given @@ -373,7 +391,7 @@ def test_subscripted_iterable_marshaller( expected_output=["field", 1], ), ) -def test_fixed_tuple_unmarshaller( +def test_fixed_tuple_marshaller( given_input, given_tuple, given_context, expected_output ): # Given @@ -419,7 +437,7 @@ def test_fixed_tuple_unmarshaller( given_input=models.TDict(field="data", value=1), ), ) -def test_structured_type_unmarshaller( +def test_structured_type_marshaller( given_input, given_cls, given_context, expected_output ): # Given @@ -456,12 +474,12 @@ def test_invalid_union(): given_marshaller(given_value) -def test_enum_unmarshaller(): +def test_enum_marshaller(): # Given - given_unmarshaller = routines.EnumMarshaller(models.GivenEnum, {}) + given_marshaller = routines.EnumMarshaller(models.GivenEnum, {}) given_value = models.GivenEnum.one expected_value = models.GivenEnum.one.value # When - unmarshalled = given_unmarshaller(given_value) + unmarshalled = given_marshaller(given_value) # Then assert unmarshalled == expected_value diff --git a/tests/unit/unmarshals/test_api.py b/tests/unit/unmarshals/test_api.py index ab86a89..c00dfaa 100644 --- a/tests/unit/unmarshals/test_api.py +++ b/tests/unit/unmarshals/test_api.py @@ -149,6 +149,11 @@ timestamp=datetime.datetime.fromtimestamp(0, datetime.timezone.utc) ), ), + optional_none=dict( + given_type=typing.Optional[typing.Union[int, str]], + given_input=None, + expected_output=None, + ), attrib_conflict=dict( given_type=models.Parent, given_input={"intersection": {"a": 0}, "child": {"intersection": {"b": 0}}}, diff --git a/tests/unit/unmarshals/test_routines.py b/tests/unit/unmarshals/test_routines.py index 5c8bad5..42c8de6 100644 --- a/tests/unit/unmarshals/test_routines.py +++ b/tests/unit/unmarshals/test_routines.py @@ -500,6 +500,24 @@ def test_literal_unmarshaller( }, expected_output=1, ), + optional_date_none=dict( + given_input=None, + given_union=typing.Optional[datetime.date], + given_context={ + datetime.date: routines.DateUnmarshaller(datetime.date, {}), + type(None): routines.NoOpUnmarshaller(type(None), {}), + }, + expected_output=None, + ), + optional_date_date=dict( + given_input=datetime.date.today().isoformat(), + given_union=typing.Optional[datetime.date], + given_context={ + datetime.date: routines.DateUnmarshaller(datetime.date, {}), + type(None): routines.NoneTypeUnmarshaller(type(None), {}), + }, + expected_output=datetime.date.today(), + ), ) def test_union_unmarshaller(given_input, given_union, given_context, expected_output): # Given