Skip to content
This repository has been archived by the owner on Sep 9, 2024. It is now read-only.

Further Improvements to Recursion Preventions #120

Merged
merged 2 commits into from
Jul 25, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[tool.poetry]
name = "typical"
packages = [{include = "typic"}]
version = "2.0.25"
version = "2.0.26"
description = "Typical: Python's Typing Toolkit."
authors = ["Sean Stewart <sean_stewart@me.com>"]
license = "MIT"
Expand Down
21 changes: 11 additions & 10 deletions typic/serde/resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@
from .des import DesFactory
from .ser import SerFactory
from .translator import TranslatorFactory
from ..util import guard_recursion, RecursionDetected

_T = TypeVar("_T")

Expand All @@ -69,6 +68,7 @@ def __init__(self):
self.translator = TranslatorFactory(self)
self.bind = self.binder.bind
self.__cache = {}
self.__stack = set()
for typ in checks.STDLIB_TYPES:
self.resolve(typ)
self.resolve(Optional[typ])
Expand Down Expand Up @@ -324,17 +324,15 @@ def _get_configuration(self, origin: Type, flags: "SerdeFlags") -> "SerdeConfig"
fields: Dict[
str, Union[Annotation, DelayedAnnotation, ForwardDelayedAnnotation]
] = {}
for name, anno in util.cached_type_hints(origin).items():
kwargs = dict(
hints = util.cached_type_hints(origin)
for name, anno in hints.items():
fields[name] = self.annotation(
anno,
flags=dataclasses.replace(flags, fields={}),
default=getattr(origin, name, EMPTY),
namespace=origin,
)
with guard_recursion(): # pragma: nocover
try:
fields[name] = self.annotation(anno, **kwargs)
except RecursionDetected:
fields[name] = self.annotation(anno, recursive=True, **kwargs)

# Filter out any annotations which aren't part of the object's signature.
if flags.signature_only:
fields = {x: fields[x] for x in fields.keys() & params.keys()}
Expand Down Expand Up @@ -400,7 +398,6 @@ def annotation(
flags: "SerdeFlags" = None,
default: Any = EMPTY,
namespace: Type = None,
recursive: bool = False,
) -> Union[Annotation, DelayedAnnotation, ForwardDelayedAnnotation]:
"""Get a :py:class:`Annotation` for this type.

Expand Down Expand Up @@ -467,7 +464,8 @@ def annotation(
module=module,
localns=localns,
)
elif use is namespace or recursive:
elif use is namespace or use in self.__stack:
self.__stack.remove(use)
return DelayedAnnotation(
type=use,
resolver=self,
Expand All @@ -478,6 +476,7 @@ def annotation(
flags=flags,
default=default,
)
self.__stack.add(use)
serde = (
self._get_configuration(util.origin(use), flags)
if is_static
Expand Down Expand Up @@ -583,6 +582,7 @@ def resolve(
--------
:py:class:`SerdeProtocol`
"""
self.__stack.clear()
# Extract the meta-data.
anno = self.annotation(
annotation=annotation,
Expand All @@ -594,6 +594,7 @@ def resolve(
namespace=namespace,
)
resolved = self._resolve_from_annotation(anno, _des, _ser, namespace)
self.__stack.clear()
return resolved

@functools.lru_cache(maxsize=None)
Expand Down
3 changes: 2 additions & 1 deletion typic/serde/ser.py
Original file line number Diff line number Diff line change
Expand Up @@ -641,7 +641,7 @@ def factory(self, annotation: "AnnotationT"):


class DelayedSerializer:
__slots__ = "anno", "factory", "_serializer"
__slots__ = "anno", "factory", "_serializer", "__name__"

def __init__(
self,
Expand All @@ -651,6 +651,7 @@ def __init__(
self.anno = anno
self.factory = factory
self._serializer: Optional[SerializerT] = None
self.__name__ = anno.name

def __call__(self, *args, **kwargs):
if self._serializer is None:
Expand Down
20 changes: 13 additions & 7 deletions typic/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,12 +412,17 @@ def cached_signature(obj: Callable) -> inspect.Signature:


def _safe_get_type_hints(annotation: Union[Type, Callable]) -> Dict[str, Type[Any]]:
raw_annotations = getattr(annotation, "__annotations__", None) or {}
module_name = getattr(annotation, "__module__", None)
if module_name:
base_globals: Optional[Dict[str, Any]] = sys.modules[module_name].__dict__
raw_annotations: Dict[str, Any] = {}
base_globals: Dict[str, Any] = {}
if isinstance(annotation, type):
for base in reversed(annotation.__mro__):
base_globals.update(sys.modules[base.__module__].__dict__)
raw_annotations.update(getattr(base, "__annotations__", None) or {})
else:
base_globals = None
raw_annotations = getattr(annotation, "__annotations__", None) or {}
module_name = getattr(annotation, "__module__", None)
if module_name:
base_globals = sys.modules[module_name].__dict__
annotations = {}
for name, value in raw_annotations.items():
if isinstance(value, str):
Expand All @@ -426,9 +431,9 @@ def _safe_get_type_hints(annotation: Union[Type, Callable]) -> Dict[str, Type[An
else:
value = ForwardRef(value)
try:
value = _eval_type(value, base_globals, None)
value = _eval_type(value, base_globals or None, None)
except NameError:
# this is ok, it can be fixed with update_forward_refs
# this is ok, we deal with it later.
pass
annotations[name] = value
return annotations
Expand Down Expand Up @@ -549,6 +554,7 @@ def __init__(self, *args):
def user_call(self, frame, argument_list):
code = frame.f_code
if code in self.stack:
self.stack.clear()
raise RecursionDetected(f"Caught recursion in: {frame}")
self.stack.add(code)

Expand Down