Skip to content

Commit

Permalink
Merge pull request #120 from seandstewart/patch/fix-recursive-type-bo…
Browse files Browse the repository at this point in the history
…ttleneck

Further Improvements to Recursion Preventions
  • Loading branch information
seandstewart committed Jul 25, 2020
2 parents 59fb59d + ff57ed2 commit 17efbff
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 19 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[tool.poetry]
name = "typical"
packages = [{include = "typic"}]
version = "2.0.25"
version = "2.0.26"
description = "Typical: Python's Typing Toolkit."
authors = ["Sean Stewart <sean_stewart@me.com>"]
license = "MIT"
Expand Down
21 changes: 11 additions & 10 deletions typic/serde/resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@
from .des import DesFactory
from .ser import SerFactory
from .translator import TranslatorFactory
from ..util import guard_recursion, RecursionDetected

_T = TypeVar("_T")

Expand All @@ -69,6 +68,7 @@ def __init__(self):
self.translator = TranslatorFactory(self)
self.bind = self.binder.bind
self.__cache = {}
self.__stack = set()
for typ in checks.STDLIB_TYPES:
self.resolve(typ)
self.resolve(Optional[typ])
Expand Down Expand Up @@ -324,17 +324,15 @@ def _get_configuration(self, origin: Type, flags: "SerdeFlags") -> "SerdeConfig"
fields: Dict[
str, Union[Annotation, DelayedAnnotation, ForwardDelayedAnnotation]
] = {}
for name, anno in util.cached_type_hints(origin).items():
kwargs = dict(
hints = util.cached_type_hints(origin)
for name, anno in hints.items():
fields[name] = self.annotation(
anno,
flags=dataclasses.replace(flags, fields={}),
default=getattr(origin, name, EMPTY),
namespace=origin,
)
with guard_recursion(): # pragma: nocover
try:
fields[name] = self.annotation(anno, **kwargs)
except RecursionDetected:
fields[name] = self.annotation(anno, recursive=True, **kwargs)

# Filter out any annotations which aren't part of the object's signature.
if flags.signature_only:
fields = {x: fields[x] for x in fields.keys() & params.keys()}
Expand Down Expand Up @@ -400,7 +398,6 @@ def annotation(
flags: "SerdeFlags" = None,
default: Any = EMPTY,
namespace: Type = None,
recursive: bool = False,
) -> Union[Annotation, DelayedAnnotation, ForwardDelayedAnnotation]:
"""Get a :py:class:`Annotation` for this type.
Expand Down Expand Up @@ -467,7 +464,8 @@ def annotation(
module=module,
localns=localns,
)
elif use is namespace or recursive:
elif use is namespace or use in self.__stack:
self.__stack.remove(use)
return DelayedAnnotation(
type=use,
resolver=self,
Expand All @@ -478,6 +476,7 @@ def annotation(
flags=flags,
default=default,
)
self.__stack.add(use)
serde = (
self._get_configuration(util.origin(use), flags)
if is_static
Expand Down Expand Up @@ -583,6 +582,7 @@ def resolve(
--------
:py:class:`SerdeProtocol`
"""
self.__stack.clear()
# Extract the meta-data.
anno = self.annotation(
annotation=annotation,
Expand All @@ -594,6 +594,7 @@ def resolve(
namespace=namespace,
)
resolved = self._resolve_from_annotation(anno, _des, _ser, namespace)
self.__stack.clear()
return resolved

@functools.lru_cache(maxsize=None)
Expand Down
3 changes: 2 additions & 1 deletion typic/serde/ser.py
Original file line number Diff line number Diff line change
Expand Up @@ -641,7 +641,7 @@ def factory(self, annotation: "AnnotationT"):


class DelayedSerializer:
__slots__ = "anno", "factory", "_serializer"
__slots__ = "anno", "factory", "_serializer", "__name__"

def __init__(
self,
Expand All @@ -651,6 +651,7 @@ def __init__(
self.anno = anno
self.factory = factory
self._serializer: Optional[SerializerT] = None
self.__name__ = anno.name

def __call__(self, *args, **kwargs):
if self._serializer is None:
Expand Down
20 changes: 13 additions & 7 deletions typic/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,12 +412,17 @@ def cached_signature(obj: Callable) -> inspect.Signature:


def _safe_get_type_hints(annotation: Union[Type, Callable]) -> Dict[str, Type[Any]]:
raw_annotations = getattr(annotation, "__annotations__", None) or {}
module_name = getattr(annotation, "__module__", None)
if module_name:
base_globals: Optional[Dict[str, Any]] = sys.modules[module_name].__dict__
raw_annotations: Dict[str, Any] = {}
base_globals: Dict[str, Any] = {}
if isinstance(annotation, type):
for base in reversed(annotation.__mro__):
base_globals.update(sys.modules[base.__module__].__dict__)
raw_annotations.update(getattr(base, "__annotations__", None) or {})
else:
base_globals = None
raw_annotations = getattr(annotation, "__annotations__", None) or {}
module_name = getattr(annotation, "__module__", None)
if module_name:
base_globals = sys.modules[module_name].__dict__
annotations = {}
for name, value in raw_annotations.items():
if isinstance(value, str):
Expand All @@ -426,9 +431,9 @@ def _safe_get_type_hints(annotation: Union[Type, Callable]) -> Dict[str, Type[An
else:
value = ForwardRef(value)
try:
value = _eval_type(value, base_globals, None)
value = _eval_type(value, base_globals or None, None)
except NameError:
# this is ok, it can be fixed with update_forward_refs
# this is ok, we deal with it later.
pass
annotations[name] = value
return annotations
Expand Down Expand Up @@ -549,6 +554,7 @@ def __init__(self, *args):
def user_call(self, frame, argument_list):
code = frame.f_code
if code in self.stack:
self.stack.clear()
raise RecursionDetected(f"Caught recursion in: {frame}")
self.stack.add(code)

Expand Down

0 comments on commit 17efbff

Please sign in to comment.