Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Bottleneck when Handling Recursive Types #118

Merged
merged 2 commits into from
Jul 2, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[tool.poetry]
name = "typical"
packages = [{include = "typic"}]
version = "2.0.24"
version = "2.0.25"
description = "Typical: Python's Typing Toolkit."
authors = ["Sean Stewart <sean_stewart@me.com>"]
license = "MIT"
Expand Down
6 changes: 6 additions & 0 deletions tests/objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,12 @@ class B:
a: typing.Optional["A"] = None


@typic.klass
class ABs:
a: typing.Optional[A] = None
bs: typing.Optional[typing.Iterable[B]] = None


@typic.klass
class C:
c: typing.Optional["C"] = None
Expand Down
10 changes: 10 additions & 0 deletions tests/test_typed.py
Original file line number Diff line number Diff line change
Expand Up @@ -823,6 +823,11 @@ class SubMeta(metaclass=objects.MetaSlotsClass):
{"d": {}, "f": {"g": {"h": "1"}}},
objects.E(objects.D(), objects.F(objects.G(1))),
),
(
objects.ABs,
{"a": {}, "bs": [{}]},
objects.ABs(a=objects.A(), bs=[objects.B()]),
),
],
)
def test_recursive_transmute(annotation, value, expected):
Expand All @@ -842,6 +847,7 @@ def test_recursive_transmute(annotation, value, expected):
(objects.D, {"d": {}}),
(objects.E, {}),
(objects.E, {"d": {}, "f": {"g": {"h": 1}}},),
(objects.ABs, {"a": {}, "bs": [{}]},),
],
)
def test_recursive_validate(annotation, value):
Expand All @@ -863,6 +869,10 @@ def test_recursive_validate(annotation, value):
objects.E(objects.D(), objects.F(objects.G(1))),
{"d": {"d": None}, "f": {"g": {"h": 1}}},
),
(
objects.ABs(a=objects.A(), bs=[objects.B()]),
{"a": {"b": None}, "bs": [{"a": None}]},
),
],
)
def test_recursive_primitive(value, expected):
Expand Down
1 change: 0 additions & 1 deletion typic/constraints/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,6 @@ def for_schema(self, *, with_type: bool = False) -> dict:
minItems=self.min_items,
maxItems=self.max_items,
uniqueItems=self.unique,
items=self.values.for_schema(with_type=True) if self.values else None,
)
if with_type:
schema["type"] = "array"
Expand Down
14 changes: 11 additions & 3 deletions typic/constraints/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import abc
import dataclasses
import enum
import reprlib
import sys
import warnings
from inspect import Signature
Expand Down Expand Up @@ -58,10 +59,8 @@ class __AbstractConstraints(abc.ABC):

__slots__ = ("__dict__",)

def __post_init__(self):
self.validator

@util.cached_property
@reprlib.recursive_repr()
def __str(self) -> str:
fields = [f"type={self.type_qualname}"]
for f in dataclasses.fields(self):
Expand Down Expand Up @@ -193,6 +192,9 @@ class BaseConstraints(__AbstractConstraints):
"""
name: Optional[str] = None

def __post_init__(self):
self.validator

def _build_validator(
self, func: gen.Block
) -> Tuple[ChecksT, ContextT]: # pragma: nocover
Expand Down Expand Up @@ -365,6 +367,9 @@ class TypeConstraints(__AbstractConstraints):
"""Whether this constraint can allow null values."""
name: Optional[str] = None

def __post_init__(self):
self.validator

@util.cached_property
def validator(self) -> ValidatorT:
ns = dict(__t=self.type, VT=VT)
Expand Down Expand Up @@ -400,6 +405,9 @@ class EnumConstraints(__AbstractConstraints):
coerce: bool = True
name: Optional[str] = None

def __post_init__(self):
self.validator

@util.cached_property
def __str(self) -> str:
values = (*(x.value for x in self.type),)
Expand Down
42 changes: 20 additions & 22 deletions typic/constraints/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,6 @@
cached_type_hints,
get_name,
TypeMap,
guard_recursion,
RecursionDetected,
)
from .array import (
Array,
Expand Down Expand Up @@ -107,7 +105,9 @@
]


def _resolve_args(*args, nullable: bool = False) -> Optional[ConstraintsT]:
def _resolve_args(
*args, cls: Type = None, nullable: bool = False
) -> Optional[ConstraintsT]:
largs: List = [*args]
items: List[ConstraintsT] = []

Expand All @@ -116,20 +116,20 @@ def _resolve_args(*args, nullable: bool = False) -> Optional[ConstraintsT]:
if arg in {Any, Ellipsis}:
continue
if origin(arg) is Union:
c = _from_union(arg, nullable=nullable)
c = _from_union(arg, cls=cls, nullable=nullable)
if isinstance(c, MultiConstraints):
items.extend(c.constraints)
else:
items.append(c)
continue
items.append(get_constraints(arg, nullable=nullable))
items.append(_maybe_get_delayed(arg, cls=cls, nullable=nullable))
if len(items) == 1:
return items[0]
return MultiConstraints((*items,)) # type: ignore


def _from_array_type(
t: Type[Array], *, nullable: bool = False, name: str = None
t: Type[Array], *, nullable: bool = False, name: str = None, cls: Type = None
) -> ArrayConstraintsT:
args = get_args(t)
constr_class = cast(
Expand All @@ -138,13 +138,13 @@ def _from_array_type(
# If we don't have args, then return a naive constraint
if not args:
return constr_class(nullable=nullable, name=name)
items = _resolve_args(*args, nullable=nullable)
items = _resolve_args(*args, cls=cls, nullable=nullable)

return constr_class(nullable=nullable, values=items, name=name)


def _from_mapping_type(
t: Type[Mapping], *, nullable: bool = False, name: str = None
t: Type[Mapping], *, nullable: bool = False, name: str = None, cls: Type = None
) -> Union[MappingConstraints, DictConstraints]:
if isbuiltintype(t):
return DictConstraints(nullable=nullable, name=name)
Expand All @@ -157,7 +157,10 @@ def _from_mapping_type(
if not args:
return constr_class(nullable=nullable, name=name)
key_arg, value_arg = args
key_items, value_items = _resolve_args(key_arg), _resolve_args(value_arg)
key_items, value_items = (
_resolve_args(key_arg, cls=cls),
_resolve_args(value_arg, cls=cls),
)
return constr_class(
keys=key_items, values=value_items, nullable=nullable, name=name
)
Expand All @@ -179,7 +182,7 @@ def _from_mapping_type(


def _from_simple_type(
t: Type[SimpleT], *, nullable: bool = False, name: str = None
t: Type[SimpleT], *, nullable: bool = False, name: str = None, cls: Type = None
) -> SimpleConstraintsT:
constr_class = cast(
Type[SimpleConstraintsT], _SIMPLE_CONSTRAINTS.get_by_parent(origin(t))
Expand Down Expand Up @@ -208,13 +211,13 @@ def _resolve_params(


def _from_strict_type(
t: Type[VT], *, nullable: bool = False, name: str = None
t: Type[VT], *, nullable: bool = False, name: str = None, cls: Type = None
) -> TypeConstraints:
return TypeConstraints(t, nullable=nullable, name=name)


def _from_enum_type(
t: Type[enum.Enum], *, nullable: bool = False, name: str = None
t: Type[enum.Enum], *, nullable: bool = False, name: str = None, cls: Type = None
) -> EnumConstraints:
return EnumConstraints(t, nullable=nullable, name=name)

Expand All @@ -239,7 +242,7 @@ def _from_union(


def _from_class(
t: Type[VT], *, nullable: bool = False, name: str = None
t: Type[VT], *, nullable: bool = False, name: str = None, cls: Type = None
) -> Union[ObjectConstraints, TypeConstraints, MappingConstraints]:
if not istypeddict(t) and not isnamedtuple(t) and isbuiltinsubtype(t):
return _from_strict_type(t, nullable=nullable, name=name)
Expand Down Expand Up @@ -322,6 +325,7 @@ def _from_class(
)


@functools.lru_cache(maxsize=None)
def _maybe_get_delayed(
t: Type[VT], *, nullable: bool = False, name: str = None, cls: Type = None
):
Expand All @@ -338,18 +342,12 @@ def _maybe_get_delayed(
return DelayedConstraints(
t, nullable=nullable, name=name, factory=get_constraints # type: ignore
)
with guard_recursion(): # pragma: nocover
try:
return get_constraints(t, nullable=nullable, name=name)
except RecursionDetected:
return DelayedConstraints(
t, nullable=nullable, name=name, factory=get_constraints # type: ignore
)
return get_constraints(t, nullable=nullable, name=name, cls=cls)


@functools.lru_cache(maxsize=None)
def get_constraints(
t: Type[VT], *, nullable: bool = False, name: str = None
t: Type[VT], *, nullable: bool = False, name: str = None, cls: Type = None
) -> ConstraintsT:
while should_unwrap(t):
nullable = nullable or isoptionaltype(t)
Expand All @@ -365,5 +363,5 @@ def get_constraints(
handler = _from_class
else:
handler = _CONSTRAINT_BUILDER_HANDLERS.get_by_parent(origin(t), _from_class) # type: ignore
c = handler(t, nullable=nullable, name=name) # type: ignore
c = handler(t, nullable=nullable, name=name, cls=cls) # type: ignore
return c
20 changes: 0 additions & 20 deletions typic/constraints/mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,26 +329,6 @@ def for_schema(self, *, with_type: bool = False) -> dict:
propertyNames=(
{"pattern": self.key_pattern.pattern} if self.key_pattern else None
),
patternProperties=(
{x: y.for_schema() for x, y in self.patterns.items()}
if self.patterns
else None
),
additionalProperties=(
self.values.for_schema(with_type=True)
if self.values
else not self.total
),
dependencies=(
{
x: y.for_schema(with_type=True)
if isinstance(y, BaseConstraints)
else y
for x, y in self.key_dependencies.items()
}
if self.key_dependencies
else None
),
)
if with_type:
schema["type"] = "object"
Expand Down
4 changes: 3 additions & 1 deletion typic/ext/schema/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import ipaddress
import pathlib
import re
import reprlib
import uuid
from typing import (
ClassVar,
Expand Down Expand Up @@ -136,7 +137,7 @@ class BaseSchemaField(_Serializable):
writeOnly: Optional[bool] = None
extensions: Optional[Tuple[frozendict.FrozenDict[str, Any], ...]] = None

__repr = cached_property(filtered_repr)
__repr = cached_property(reprlib.recursive_repr()(filtered_repr))

def __repr__(self) -> str: # pragma: nocover
return self.__repr
Expand Down Expand Up @@ -325,6 +326,7 @@ class ArraySchemaField(BaseSchemaField):
MultiSchemaField,
UndeclaredSchemaField,
NullSchemaField,
Ref,
]
"""A type-alias for the defined JSON Schema Fields."""

Expand Down
Loading