Skip to content

Commit

Permalink
Merge pull request #117 from seandstewart/patch/115/fix-generator-ite…
Browse files Browse the repository at this point in the history
…ration

Fixes for Generator Deserialization and Recursion Detection
  • Loading branch information
seandstewart committed Jul 2, 2020
2 parents 44ff838 + 5e301da commit 6ab0ff0
Show file tree
Hide file tree
Showing 9 changed files with 94 additions and 40 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[tool.poetry]
name = "typical"
packages = [{include = "typic"}]
version = "2.0.23"
version = "2.0.24"
description = "Typical: Python's Typing Toolkit."
authors = ["Sean Stewart <sean_stewart@me.com>"]
license = "MIT"
Expand Down
1 change: 1 addition & 0 deletions tests/test_typed.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ def test_isbuiltintype(obj: typing.Any):
(objects.Dest, objects.Source(), objects.Dest(objects.Source().test)), # type: ignore
(MyClass, factory(), MyClass(1)),
(defaultdict, {}, defaultdict(None)),
(list, (x for x in range(10)), [*range(10)]),
],
ids=objects.get_id,
)
Expand Down
7 changes: 7 additions & 0 deletions typic/checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
Type,
Tuple,
TypeVar,
Iterable,
)

import typic
Expand Down Expand Up @@ -367,6 +368,12 @@ def isuuidtype(obj: Type[ObjectT]) -> bool:
_COLLECTIONS = {list, set, tuple, frozenset, dict, str, bytes}


@functools.lru_cache(maxsize=None)
def isiterabletype(obj: Type[ObjectT]):
obj = util.origin(obj)
return _issubclass(obj, Iterable)


@functools.lru_cache(maxsize=None)
def iscollectiontype(obj: Type[ObjectT]):
"""Test whether this annotation is a subclass of :py:class:`typing.Collection`.
Expand Down
10 changes: 10 additions & 0 deletions typic/constraints/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,6 +460,16 @@ def __getattr__(self, item):


class ForwardDelayedConstraints:
__slots__ = (
"ref",
"module",
"localns",
"nullable",
"name",
"factory",
"_constraints",
)

def __init__(
self,
ref: ForwardRef,
Expand Down
13 changes: 10 additions & 3 deletions typic/constraints/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@
cached_type_hints,
get_name,
TypeMap,
recursing,
guard_recursion,
RecursionDetected,
)
from .array import (
Array,
Expand Down Expand Up @@ -333,11 +334,17 @@ def _maybe_get_delayed(
name=name, # type: ignore
factory=get_constraints,
)
elif t is cls or recursing():
elif t is cls:
return DelayedConstraints(
t, nullable=nullable, name=name, factory=get_constraints # type: ignore
)
return get_constraints(t, nullable=nullable, name=name)
with guard_recursion(): # pragma: nocover
try:
return get_constraints(t, nullable=nullable, name=name)
except RecursionDetected:
return DelayedConstraints(
t, nullable=nullable, name=name, factory=get_constraints # type: ignore
)


@functools.lru_cache(maxsize=None)
Expand Down
8 changes: 5 additions & 3 deletions typic/klass.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from typic.api import wrap_cls, ObjectT
from typic.types import freeze
from .serde.common import SerdeFlags
from typic.util import slotted, recursing
from typic.util import slotted, guard_recursion, RecursionDetected

_field_slots: Tuple[str, ...] = cast(Tuple[str, ...], dataclasses.Field.__slots__) + (
"exclude",
Expand Down Expand Up @@ -145,13 +145,15 @@ def make_typedclass(
frozen=frozen,
)
if slots:
if recursing():
try:
with guard_recursion():
dcls = slotted(dcls)
except RecursionDetected:
raise TypeError(
f"{cls!r} uses a custom metaclass {cls.__class__!r} "
"which is not compatible with the 'slots' operator. "
"See Issue #104 on GitHub for more information."
) from None
dcls = slotted(dcls)
fields = [
f if isinstance(f, Field) else Field.from_field(f)
for f in dataclasses.fields(dcls)
Expand Down
34 changes: 12 additions & 22 deletions typic/serde/resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
from .des import DesFactory
from .ser import SerFactory
from .translator import TranslatorFactory
from ..util import guard_recursion, RecursionDetected

_T = TypeVar("_T")

Expand Down Expand Up @@ -312,15 +313,6 @@ def tojson(
proto: SerdeProtocol = self._get_serializer_proto(t)
return proto.tojson(obj, indent=indent, ensure_ascii=ensure_ascii, **kwargs)

@staticmethod
def _self_referencing(ref: Union[Type, ForwardRef], origin: Type):
args = util.get_args(ref) or (ref,)
for arg in args:
recursive = arg is origin
if recursive:
return recursive
return False

@functools.lru_cache(maxsize=None)
def _get_configuration(self, origin: Type, flags: "SerdeFlags") -> "SerdeConfig":
if hasattr(origin, SERDE_FLAGS_ATTR):
Expand All @@ -332,18 +324,16 @@ def _get_configuration(self, origin: Type, flags: "SerdeFlags") -> "SerdeConfig"
str, Union[Annotation, DelayedAnnotation, ForwardDelayedAnnotation]
] = {}
for name, anno in util.cached_type_hints(origin).items():
namespace: Optional[Type] = origin
recursive = True
if self._self_referencing(anno, origin) and not util.recursing():
namespace = None
recursive = False
fields[name] = self.annotation(
anno,
kwargs = dict(
flags=dataclasses.replace(flags, fields={}),
default=getattr(origin, name, EMPTY),
namespace=namespace,
recursive=recursive,
namespace=origin,
)
with guard_recursion(): # pragma: nocover
try:
fields[name] = self.annotation(anno, **kwargs)
except RecursionDetected:
fields[name] = self.annotation(anno, recursive=True, **kwargs)
# Filter out any annotations which aren't part of the object's signature.
if flags.signature_only:
fields = {x: fields[x] for x in fields.keys() & params.keys()}
Expand Down Expand Up @@ -374,10 +364,10 @@ def _get_configuration(self, origin: Type, flags: "SerdeFlags") -> "SerdeConfig"
anno = fields[name]
default = anno.parameter.default if anno.parameter else EMPTY
if isinstance(anno, ForwardDelayedAnnotation):
if {
util.get_name(anno.ref),
util.get_name(default),
} & type_name_omissions:
if (
not {util.get_name(anno.ref), util.get_name(default)}
& type_name_omissions
):
fields_out_final[name] = out
elif not {anno.origin, default} & type_omissions:
fields_out_final[name] = out
Expand Down
4 changes: 2 additions & 2 deletions typic/serde/translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
Any,
)

from typic.checks import iscollectiontype, ismappingtype
from typic.checks import ismappingtype, isiterabletype
from typic.gen import Block, Keyword
from typic.util import (
cached_type_hints,
Expand Down Expand Up @@ -132,7 +132,7 @@ def iterator(self, type: Type, values: bool = False) -> "FieldIteratorT":
iter = _valuescaller if values else _itemscaller
return iter

if iscollectiontype(type):
if isiterabletype(type):
return _iter

fields = self.get_fields(type, as_source=True) or {}
Expand Down
55 changes: 46 additions & 9 deletions typic/util.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
import ast
import bdb
import collections
import contextlib
import dataclasses
import functools
import inspect
Expand Down Expand Up @@ -528,15 +530,50 @@ def get_by_parent(self, t: Type, default: VT = None) -> Optional[VT]:
return default


def recursing() -> bool:
"""Detect whether we're in a recursive loop."""
frames = inspect.getouterframes(inspect.currentframe())[1:]
top_frame = inspect.getframeinfo(frames[0][0])
for frame, _, _, _, _, _ in frames[1:]:
(path, line_number, func_name, lines, index) = inspect.getframeinfo(frame)
if path == top_frame[0] and func_name == top_frame[2]:
return True
return False
class RecursionDetected(RuntimeError):
...


class RecursionDetector(bdb.Bdb): # pragma: nocover
"""Prevent recursion from even starting.
https://stackoverflow.com/a/36663046
Warnings
--------
While the detector is tracing, no other debug tracers (i.e., codecov!) can trace.
"""

def do_clear(self, arg):
pass

def __init__(self, *args):
bdb.Bdb.__init__(self, *args)
self.stack = set()

def user_call(self, frame, argument_list):
code = frame.f_code
if code in self.stack:
raise RecursionDetected(f"Caught recursion in: {frame}")
self.stack.add(code)

def user_return(self, frame, return_value):
if frame.f_code in self.stack:
self.stack.remove(frame.f_code)


_detector = RecursionDetector()


@contextlib.contextmanager
def guard_recursion(): # pragma: nocover
curtrace = sys.gettrace()
_detector.set_trace()
try:
yield
finally:
_detector.stack.clear()
sys.settrace(curtrace)


def slotted(
Expand Down

0 comments on commit 6ab0ff0

Please sign in to comment.