Skip to content

Commit

Permalink
Fix desert handling of metadata not being passed to fields
Browse files Browse the repository at this point in the history
  • Loading branch information
isra17 committed Apr 27, 2023
1 parent 0c43ff5 commit 61e8105
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 25 deletions.
56 changes: 31 additions & 25 deletions src/desert/_make.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,8 +220,8 @@ def field_for_schema(

if default is not marshmallow.missing:
desert_metadata.setdefault("dump_default", default)
desert_metadata.setdefault("allow_none", True)
desert_metadata.setdefault("load_default", default)
desert_metadata.setdefault("allow_none", True)

field = None

Expand All @@ -235,9 +235,20 @@ def field_for_schema(
field.metadata.update(metadata)
return field

field_args = {
k: v
for k, v in desert_metadata.items()
if k
in [
"dump_default",
"load_default",
"allow_none",
]
}

# Base types
if not field and typ in _native_to_marshmallow:
field = _native_to_marshmallow[typ](dump_default=default)
field = _native_to_marshmallow[typ](**field_args)

# Generic types
origin = typing_inspect.get_origin(typ)
Expand All @@ -253,16 +264,18 @@ def field_for_schema(
collections.abc.Sequence,
collections.abc.MutableSequence,
):
field = marshmallow.fields.List(field_for_schema(arguments[0]))
field = marshmallow.fields.List(
field_for_schema(arguments[0]), **field_args
)

if origin in (tuple, t.Tuple) and Ellipsis not in arguments:
field = marshmallow.fields.Tuple( # type: ignore[no-untyped-call]
tuple(field_for_schema(arg) for arg in arguments)
tuple(field_for_schema(arg) for arg in arguments), **field_args
)
elif origin in (tuple, t.Tuple) and Ellipsis in arguments:

field = VariadicTuple(
field_for_schema(only(arg for arg in arguments if arg != Ellipsis))
field_for_schema(only(arg for arg in arguments if arg != Ellipsis)),
**field_args,
)
elif origin in (
dict,
Expand All @@ -275,22 +288,15 @@ def field_for_schema(
field = marshmallow.fields.Dict(
keys=field_for_schema(arguments[0]),
values=field_for_schema(arguments[1]),
**field_args,
)
elif typing_inspect.is_optional_type(typ):
[subtyp] = (t for t in arguments if t is not NoneType)
# Treat optional types as types with a None default
metadata[_DESERT_SENTINEL]["dump_default"] = metadata.get(
"dump_default", None
)
metadata[_DESERT_SENTINEL]["load_default"] = metadata.get(
"load_default", None
)
metadata[_DESERT_SENTINEL]["required"] = False

field = field_for_schema(subtyp, metadata=metadata, default=None)
field.dump_default = None
field.load_default = None
field.allow_none = True
metadata[_DESERT_SENTINEL]["allow_none"] = True
if default is marshmallow.missing:
default = None
field = field_for_schema(subtyp, metadata=metadata, default=default)

elif typing_inspect.is_union_type(typ):
subfields = [field_for_schema(subtyp) for subtyp in arguments]
Expand All @@ -302,7 +308,7 @@ def field_for_schema(
newtype_supertype = getattr(typ, "__supertype__", None)
if newtype_supertype and typing_inspect.is_new_type(typ):
metadata.setdefault("description", typ.__name__)
field = field_for_schema(newtype_supertype, default=default)
field = field_for_schema(newtype_supertype, metadata=metadata, default=default)

# enumerations
if type(typ) is enum.EnumMeta:
Expand All @@ -315,7 +321,7 @@ def field_for_schema(

if field is None:
nested = forward_reference or class_schema(typ)
field = marshmallow.fields.Nested(nested)
field = marshmallow.fields.Nested(nested, **field_args)

field.metadata.update(metadata)

Expand Down Expand Up @@ -350,11 +356,11 @@ def _get_field_default(
if isinstance(field, dataclasses.Field):
# misc: https://github.com/python/mypy/issues/10750
# comparison-overlap: https://github.com/python/typeshed/pull/5900
if field.default_factory != dataclasses.MISSING:
return dataclasses.MISSING
if field.default is dataclasses.MISSING:
return marshmallow.missing
return field.default
if field.default_factory is not dataclasses.MISSING:
return field.default_factory
if field.default is not dataclasses.MISSING:
return field.default
return marshmallow.missing
elif isinstance(field, attr.Attribute):
if field.default == attr.NOTHING:
return marshmallow.missing
Expand Down
11 changes: 11 additions & 0 deletions tests/test_make.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,17 @@ class A:
assert data == A(None) # type: ignore[call-arg]


def test_optional_default(module: DataclassModule) -> None:
"""Setting an optional type allows passing None."""

@module.dataclass
class A:
x: t.Optional[int] = 1

data = desert.schema_class(A)().load({})
assert data == A(1) # type: ignore[call-arg]


def test_custom_field(module: DataclassModule) -> None:
@module.dataclass
class A:
Expand Down

0 comments on commit 61e8105

Please sign in to comment.