Skip to content

Commit

Permalink
fix: Add types in typing module to scope
Browse files Browse the repository at this point in the history
`iter_type` function produces the types in typing module now.

e.g.  `iter_types(Dict[str, int])` produces (Dict, str, int)
  • Loading branch information
yukinarit committed Nov 18, 2021
1 parent 05e0b6f commit e12e802
Show file tree
Hide file tree
Showing 6 changed files with 104 additions and 36 deletions.
60 changes: 54 additions & 6 deletions serde/compat.py
Expand Up @@ -72,9 +72,17 @@ def typename(typ) -> str:
'Any'
"""
if is_opt(typ):
return f'Optional[{typename(type_args(typ)[0])}]'
args = type_args(typ)
if args:
return f'Optional[{typename(type_args(typ)[0])}]'
else:
return 'Optional'
elif is_union(typ):
return f'Union[{", ".join([typename(e) for e in union_args(typ)])}]'
args = union_args(typ)
if args:
return f'Union[{", ".join([typename(e) for e in args])}]'
else:
return 'Union'
elif is_list(typ):
args = type_args(typ)
if args:
Expand All @@ -98,7 +106,11 @@ def typename(typ) -> str:
else:
return 'Dict'
elif is_tuple(typ):
return f'Tuple[{", ".join([typename(e) for e in type_args(typ)])}]'
args = type_args(typ)
if args:
return f'Tuple[{", ".join([typename(e) for e in args])}]'
else:
return 'Tuple'
elif typ is Any:
return 'Any'
else:
Expand Down Expand Up @@ -170,9 +182,12 @@ def dataclass_fields(cls: Type) -> Iterator:
return iter(raw_fields)


def iter_types(cls: Type) -> Iterator[Type]:
def iter_types(cls: Type) -> Iterator[Union[Type, typing.Any]]:
"""
Iterate field types recursively.
The correct return type is `Iterator[Union[Type, typing._specialform]],
but `typing._specialform` doesn't exist for python 3.6. Use `Any` instead.
"""
if is_dataclass(cls):
yield cls
Expand All @@ -181,20 +196,30 @@ def iter_types(cls: Type) -> Iterator[Type]:
elif isinstance(cls, str):
yield cls
elif is_opt(cls):
yield Optional
arg = type_args(cls)
if arg:
yield from iter_types(arg[0])
elif is_union(cls):
yield Union
for arg in type_args(cls):
yield from iter_types(arg)
elif is_list(cls) or is_set(cls):
yield List
arg = type_args(cls)
if arg:
yield from iter_types(arg[0])
elif is_set(cls):
yield Set
arg = type_args(cls)
if arg:
yield from iter_types(arg[0])
elif is_tuple(cls):
yield Tuple
for arg in type_args(cls):
yield from iter_types(arg)
elif is_dict(cls):
yield Dict
arg = type_args(cls)
if arg and len(arg) >= 2:
yield from iter_types(arg[0])
Expand Down Expand Up @@ -242,9 +267,32 @@ def is_union(typ) -> bool:
def is_opt(typ) -> bool:
"""
Test if the type is `typing.Optional`.
>>> is_opt(Optional[int])
True
>>> is_opt(Optional)
True
>>> is_opt(None.__class__)
False
"""
args = type_args(typ)
if args:
return typing_inspect.is_optional_type(typ) and len(args) == 2 and not is_none(args[0]) and is_none(args[1])
else:
return typ is Optional


def is_bare_opt(typ) -> bool:
"""
Test if the type is `typing.Optional` without type args.
>>> is_bare_opt(Optional[int])
False
>>> is_bare_opt(Optional)
True
>>> is_bare_opt(None.__class__)
False
"""
args = get_args(typ)
return typing_inspect.is_optional_type(typ) and len(args) == 2 and not is_none(args[0]) and is_none(args[1])
return not type_args(typ) and typ is Optional


def is_list(typ) -> bool:
Expand Down
7 changes: 7 additions & 0 deletions serde/core.py
Expand Up @@ -455,3 +455,10 @@ def union_func_name(prefix: str, union_args: List[Type]) -> str:
'union_se_int_List_str__IPv4Address'
"""
return re.sub(r"[ ,\[\]]+", "_", f"{prefix}_{'_'.join([typename(e) for e in union_args])}")


def filter_scope(scope: Dict[str, Any]) -> Iterator[str]:
for k, v in scope.items():
if v.__module__ == "typing":
continue
yield k
21 changes: 11 additions & 10 deletions serde/de.py
Expand Up @@ -46,6 +46,7 @@
StrSerializableTypes,
add_func,
fields,
filter_scope,
logger,
raise_unsupported_type,
union_func_name,
Expand Down Expand Up @@ -169,14 +170,11 @@ def wrap(cls: Type):

# Collect types used in the generated code.
for typ in iter_types(cls):
if typ is cls:
if typ is cls or (is_primitive(typ) and not is_enum(typ)):
continue

if typ is Any:
continue

if is_dataclass(typ) or is_enum(typ) or not is_primitive(typ):
scope.types[typ.__name__] = typ
scope.types[typename(typ)] = typ
g[typename(typ)] = typ

# render all union functions
for union in iter_unions(cls):
Expand Down Expand Up @@ -473,7 +471,7 @@ def render(self, arg: DeField) -> str:
if self.custom and not arg.deserializer:
# The function takes a closure in order to execute the default value lazily.
return (
f'serde_custom_class_deserializer({arg.type.__name__}, {arg.datavar}, {arg.data}, '
f'serde_custom_class_deserializer({typename(arg.type)}, {arg.datavar}, {arg.data}, '
f'default=lambda: {res})'
)
else:
Expand Down Expand Up @@ -665,7 +663,7 @@ def {{func}}(data, reuse_instances = {{serde_scope.reuse_instances_default}}):
reuse_instances = {{serde_scope.reuse_instances_default}}
{# List up all classes used by this class. -#}
{% for name in serde_scope.types.keys() -%}
{% for name in serde_scope.types|filter_scope -%}
{{name}} = serde_scope.types['{{name}}']
{% endfor -%}
Expand All @@ -683,6 +681,7 @@ def {{func}}(data, reuse_instances = {{serde_scope.reuse_instances_default}}):
env = jinja2.Environment(loader=jinja2.DictLoader({'iter': template}))
env.filters.update({'rvalue': renderer.render})
env.filters.update({'arg': to_iter_arg})
env.filters.update({'filter_scope': filter_scope})
return env.get_template('iter').render(func=FROM_ITER, serde_scope=getattr(cls, SERDE_SCOPE), fields=defields(cls))


Expand All @@ -693,7 +692,7 @@ def {{func}}(data, reuse_instances = {{serde_scope.reuse_instances_default}}):
reuse_instances = {{serde_scope.reuse_instances_default}}
{# List up all classes used by this class. #}
{% for name in serde_scope.types.keys() %}
{% for name in serde_scope.types|filter_scope %}
{{name}} = serde_scope.types['{{name}}']
{% endfor %}
Expand All @@ -711,13 +710,14 @@ def {{func}}(data, reuse_instances = {{serde_scope.reuse_instances_default}}):
env = jinja2.Environment(loader=jinja2.DictLoader({'dict': template}))
env.filters.update({'rvalue': renderer.render})
env.filters.update({'arg': functools.partial(to_arg, rename_all=rename_all)})
env.filters.update({'filter_scope': filter_scope})
return env.get_template('dict').render(func=FROM_DICT, serde_scope=getattr(cls, SERDE_SCOPE), fields=defields(cls))


def render_union_func(cls: Type, union_args: List[Type]) -> str:
template = """
def {{func}}(data, reuse_instances):
{% for name in serde_scope.types.keys() %}
{% for name in serde_scope.types|filter_scope %}
{{name}} = serde_scope.types['{{name}}']
{% endfor %}
Expand Down Expand Up @@ -751,6 +751,7 @@ def {{func}}(data, reuse_instances):
env.filters.update({'rvalue': renderer.render})
env.filters.update({'is_primitive': is_primitive})
env.filters.update({'is_none': is_none})
env.filters.update({'filter_scope': filter_scope})
return env.get_template('dict').render(
func=union_func_name(UNION_DE_PREFIX, union_args),
serde_scope=getattr(cls, SERDE_SCOPE),
Expand Down
31 changes: 18 additions & 13 deletions serde/se.py
Expand Up @@ -18,6 +18,7 @@
SerdeSkip,
is_bare_dict,
is_bare_list,
is_bare_opt,
is_bare_set,
is_bare_tuple,
is_dict,
Expand Down Expand Up @@ -46,6 +47,7 @@
add_func,
conv,
fields,
filter_scope,
is_instance,
logger,
raise_unsupported_type,
Expand Down Expand Up @@ -179,14 +181,11 @@ def wrap(cls: Type):

# Collect types used in the generated code.
for typ in iter_types(cls):
if typ is cls:
if typ is cls or (is_primitive(typ) and not is_enum(typ)):
continue

if typ is Any:
continue

if is_dataclass(typ) or is_enum(typ) or not is_primitive(typ):
scope.types[typ.__name__] = typ
scope.types[typename(typ)] = typ
g[typename(typ)] = typ

# render all union functions
for union in iter_unions(cls):
Expand Down Expand Up @@ -376,7 +375,7 @@ def {{func}}(obj, reuse_instances = {{serde_scope.reuse_instances_default}},
return copy.deepcopy(obj)
{# List up all classes used by this class. #}
{% for name in serde_scope.types.keys() %}
{% for name in serde_scope.types|filter_scope %}
{{name}} = serde_scope.types['{{name}}']
{% endfor %}
Expand All @@ -392,6 +391,7 @@ def {{func}}(obj, reuse_instances = {{serde_scope.reuse_instances_default}},
renderer = Renderer(TO_ITER, custom)
env = jinja2.Environment(loader=jinja2.DictLoader({'iter': template}))
env.filters.update({'rvalue': renderer.render})
env.filters.update({'filter_scope': filter_scope})
return env.get_template('iter').render(func=TO_ITER, serde_scope=getattr(cls, SERDE_SCOPE), fields=sefields(cls))


Expand All @@ -408,7 +408,7 @@ def {{func}}(obj, reuse_instances = {{serde_scope.reuse_instances_default}},
return copy.deepcopy(obj)
{# List up all classes used by this class. #}
{% for name in serde_scope.types.keys() -%}
{% for name in serde_scope.types|filter_scope -%}
{{name}} = serde_scope.types['{{name}}']
{% endfor -%}
Expand All @@ -432,13 +432,14 @@ def {{func}}(obj, reuse_instances = {{serde_scope.reuse_instances_default}},
env.filters.update({'rvalue': renderer.render})
env.filters.update({'lvalue': lrenderer.render})
env.filters.update({'case': functools.partial(conv, case=case)})
env.filters.update({'filter_scope': filter_scope})
return env.get_template('dict').render(func=TO_DICT, serde_scope=getattr(cls, SERDE_SCOPE), fields=sefields(cls))


def render_union_func(cls: Type, union_args: List[Type]) -> str:
template = """
def {{func}}(obj, reuse_instances, convert_sets):
{% for name in serde_scope.types.keys() %}
{% for name in serde_scope.types|filter_scope %}
{{name}} = serde_scope.types['{{name}}']
{% endfor %}
Expand All @@ -456,6 +457,7 @@ def {{func}}(obj, reuse_instances, convert_sets):
env = jinja2.Environment(loader=jinja2.DictLoader({'dict': template}))
env.filters.update({'arg': lambda x: SeField(x, "obj")})
env.filters.update({'rvalue': renderer.render})
env.filters.update({'filter_scope': filter_scope})
return env.get_template('dict').render(
func=union_func_name(UNION_SE_PREFIX, union_args),
serde_scope=getattr(cls, SERDE_SCOPE),
Expand Down Expand Up @@ -567,7 +569,7 @@ def render(self, arg: SeField) -> str:

# Custom field serializer overrides custom class serializer.
if self.custom and not arg.serializer:
return f'serde_custom_class_serializer({arg.type.__name__}, {arg.varname}, default=lambda: {res})'
return f'serde_custom_class_serializer({typename(arg.type)}, {arg.varname}, default=lambda: {res})'
else:
return res

Expand Down Expand Up @@ -598,9 +600,12 @@ def opt(self, arg: SeField) -> str:
"""
Render rvalue for optional.
"""
inner = arg[0]
inner.name = arg.varname
return f'({self.render(inner)}) if {arg.varname} is not None else None'
if is_bare_opt(arg.type):
return f'{arg.varname} if {arg.varname} is not None else None'
else:
inner = arg[0]
inner.name = arg.varname
return f'({self.render(inner)}) if {arg.varname} is not None else None'

def list(self, arg: SeField) -> str:
"""
Expand Down
13 changes: 9 additions & 4 deletions tests/test_compat.py
Expand Up @@ -13,6 +13,7 @@
iter_types,
iter_unions,
type_args,
typename,
union_args,
)
from serde.core import is_instance
Expand Down Expand Up @@ -52,12 +53,16 @@ def test_types():
assert is_dict(dict[str, int])


def test_typename():
assert typename(Optional) == "Optional"


def test_iter_types():
assert [Pri, int, str, float, bool] == list(iter_types(Pri))
assert [str, Pri, int, str, float, bool] == list(iter_types(Dict[str, Pri]))
assert [str] == list(iter_types(List[str]))
assert [int, str, bool, float] == list(iter_types(Tuple[int, str, bool, float]))
assert [PriOpt, int, str, float, bool] == list(iter_types(PriOpt))
assert [Dict, str, Pri, int, str, float, bool] == list(iter_types(Dict[str, Pri]))
assert [List, str] == list(iter_types(List[str]))
assert [Tuple, int, str, bool, float] == list(iter_types(Tuple[int, str, bool, float]))
assert [PriOpt, Optional, int, Optional, str, Optional, float, Optional, bool] == list(iter_types(PriOpt))


def test_iter_unions():
Expand Down
8 changes: 5 additions & 3 deletions tests/test_custom.py
Expand Up @@ -3,7 +3,7 @@
"""
from dataclasses import dataclass, field
from datetime import datetime
from typing import Optional
from typing import Optional, Union

import pytest

Expand Down Expand Up @@ -93,14 +93,16 @@ class Foo:
i: int
dt1: datetime
dt2: datetime
s: Optional[str] = None
u: Union[str, int] = 10

dt = datetime(2021, 1, 1, 0, 0, 0)
f = Foo(10, dt, dt)

assert to_json(f) == '{"i": 10, "dt1": "01/01/21", "dt2": "01/01/21"}'
assert to_json(f) == '{"i": 10, "dt1": "01/01/21", "dt2": "01/01/21", "s": null, "u": 10}'
assert f == from_json(Foo, to_json(f))

assert to_tuple(f) == (10, '01/01/21', '01/01/21')
assert to_tuple(f) == (10, '01/01/21', '01/01/21', None, 10)
assert f == from_tuple(Foo, to_tuple(f))


Expand Down

0 comments on commit e12e802

Please sign in to comment.