diff --git a/pyproject.toml b/pyproject.toml index f301c039..30589f47 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [tool.poetry] name = "typical" packages = [{include = "typic"}] -version = "2.0.24" +version = "2.0.25" description = "Typical: Python's Typing Toolkit." authors = ["Sean Stewart "] license = "MIT" diff --git a/tests/objects.py b/tests/objects.py index 0a0c6e6d..3a7af036 100644 --- a/tests/objects.py +++ b/tests/objects.py @@ -87,6 +87,12 @@ class B: a: typing.Optional["A"] = None +@typic.klass +class ABs: + a: typing.Optional[A] = None + bs: typing.Optional[typing.Iterable[B]] = None + + @typic.klass class C: c: typing.Optional["C"] = None diff --git a/tests/test_typed.py b/tests/test_typed.py index b390fa03..cfbc6308 100644 --- a/tests/test_typed.py +++ b/tests/test_typed.py @@ -823,6 +823,11 @@ class SubMeta(metaclass=objects.MetaSlotsClass): {"d": {}, "f": {"g": {"h": "1"}}}, objects.E(objects.D(), objects.F(objects.G(1))), ), + ( + objects.ABs, + {"a": {}, "bs": [{}]}, + objects.ABs(a=objects.A(), bs=[objects.B()]), + ), ], ) def test_recursive_transmute(annotation, value, expected): @@ -842,6 +847,7 @@ def test_recursive_transmute(annotation, value, expected): (objects.D, {"d": {}}), (objects.E, {}), (objects.E, {"d": {}, "f": {"g": {"h": 1}}},), + (objects.ABs, {"a": {}, "bs": [{}]},), ], ) def test_recursive_validate(annotation, value): @@ -863,6 +869,10 @@ def test_recursive_validate(annotation, value): objects.E(objects.D(), objects.F(objects.G(1))), {"d": {"d": None}, "f": {"g": {"h": 1}}}, ), + ( + objects.ABs(a=objects.A(), bs=[objects.B()]), + {"a": {"b": None}, "bs": [{"a": None}]}, + ), ], ) def test_recursive_primitive(value, expected): diff --git a/typic/constraints/array.py b/typic/constraints/array.py index 5e5c1864..e8c15107 100644 --- a/typic/constraints/array.py +++ b/typic/constraints/array.py @@ -162,7 +162,6 @@ def for_schema(self, *, with_type: bool = False) -> dict: minItems=self.min_items, 
maxItems=self.max_items, uniqueItems=self.unique, - items=self.values.for_schema(with_type=True) if self.values else None, ) if with_type: schema["type"] = "array" diff --git a/typic/constraints/common.py b/typic/constraints/common.py index 90311221..5b30627d 100644 --- a/typic/constraints/common.py +++ b/typic/constraints/common.py @@ -3,6 +3,7 @@ import abc import dataclasses import enum +import reprlib import sys import warnings from inspect import Signature @@ -58,10 +59,8 @@ class __AbstractConstraints(abc.ABC): __slots__ = ("__dict__",) - def __post_init__(self): - self.validator - @util.cached_property + @reprlib.recursive_repr() def __str(self) -> str: fields = [f"type={self.type_qualname}"] for f in dataclasses.fields(self): @@ -193,6 +192,9 @@ class BaseConstraints(__AbstractConstraints): """ name: Optional[str] = None + def __post_init__(self): + self.validator + def _build_validator( self, func: gen.Block ) -> Tuple[ChecksT, ContextT]: # pragma: nocover @@ -365,6 +367,9 @@ class TypeConstraints(__AbstractConstraints): """Whether this constraint can allow null values.""" name: Optional[str] = None + def __post_init__(self): + self.validator + @util.cached_property def validator(self) -> ValidatorT: ns = dict(__t=self.type, VT=VT) @@ -400,6 +405,9 @@ class EnumConstraints(__AbstractConstraints): coerce: bool = True name: Optional[str] = None + def __post_init__(self): + self.validator + @util.cached_property def __str(self) -> str: values = (*(x.value for x in self.type),) diff --git a/typic/constraints/factory.py b/typic/constraints/factory.py index 982bc44b..436f366d 100644 --- a/typic/constraints/factory.py +++ b/typic/constraints/factory.py @@ -41,8 +41,6 @@ cached_type_hints, get_name, TypeMap, - guard_recursion, - RecursionDetected, ) from .array import ( Array, @@ -107,7 +105,9 @@ ] -def _resolve_args(*args, nullable: bool = False) -> Optional[ConstraintsT]: +def _resolve_args( + *args, cls: Type = None, nullable: bool = False +) -> 
Optional[ConstraintsT]: largs: List = [*args] items: List[ConstraintsT] = [] @@ -116,20 +116,20 @@ def _resolve_args(*args, nullable: bool = False) -> Optional[ConstraintsT]: if arg in {Any, Ellipsis}: continue if origin(arg) is Union: - c = _from_union(arg, nullable=nullable) + c = _from_union(arg, cls=cls, nullable=nullable) if isinstance(c, MultiConstraints): items.extend(c.constraints) else: items.append(c) continue - items.append(get_constraints(arg, nullable=nullable)) + items.append(_maybe_get_delayed(arg, cls=cls, nullable=nullable)) if len(items) == 1: return items[0] return MultiConstraints((*items,)) # type: ignore def _from_array_type( - t: Type[Array], *, nullable: bool = False, name: str = None + t: Type[Array], *, nullable: bool = False, name: str = None, cls: Type = None ) -> ArrayConstraintsT: args = get_args(t) constr_class = cast( @@ -138,13 +138,13 @@ def _from_array_type( # If we don't have args, then return a naive constraint if not args: return constr_class(nullable=nullable, name=name) - items = _resolve_args(*args, nullable=nullable) + items = _resolve_args(*args, cls=cls, nullable=nullable) return constr_class(nullable=nullable, values=items, name=name) def _from_mapping_type( - t: Type[Mapping], *, nullable: bool = False, name: str = None + t: Type[Mapping], *, nullable: bool = False, name: str = None, cls: Type = None ) -> Union[MappingConstraints, DictConstraints]: if isbuiltintype(t): return DictConstraints(nullable=nullable, name=name) @@ -157,7 +157,10 @@ def _from_mapping_type( if not args: return constr_class(nullable=nullable, name=name) key_arg, value_arg = args - key_items, value_items = _resolve_args(key_arg), _resolve_args(value_arg) + key_items, value_items = ( + _resolve_args(key_arg, cls=cls), + _resolve_args(value_arg, cls=cls), + ) return constr_class( keys=key_items, values=value_items, nullable=nullable, name=name ) @@ -179,7 +182,7 @@ def _from_mapping_type( def _from_simple_type( - t: Type[SimpleT], *, nullable: bool 
= False, name: str = None + t: Type[SimpleT], *, nullable: bool = False, name: str = None, cls: Type = None ) -> SimpleConstraintsT: constr_class = cast( Type[SimpleConstraintsT], _SIMPLE_CONSTRAINTS.get_by_parent(origin(t)) @@ -208,13 +211,13 @@ def _resolve_params( def _from_strict_type( - t: Type[VT], *, nullable: bool = False, name: str = None + t: Type[VT], *, nullable: bool = False, name: str = None, cls: Type = None ) -> TypeConstraints: return TypeConstraints(t, nullable=nullable, name=name) def _from_enum_type( - t: Type[enum.Enum], *, nullable: bool = False, name: str = None + t: Type[enum.Enum], *, nullable: bool = False, name: str = None, cls: Type = None ) -> EnumConstraints: return EnumConstraints(t, nullable=nullable, name=name) @@ -239,7 +242,7 @@ def _from_union( def _from_class( - t: Type[VT], *, nullable: bool = False, name: str = None + t: Type[VT], *, nullable: bool = False, name: str = None, cls: Type = None ) -> Union[ObjectConstraints, TypeConstraints, MappingConstraints]: if not istypeddict(t) and not isnamedtuple(t) and isbuiltinsubtype(t): return _from_strict_type(t, nullable=nullable, name=name) @@ -322,6 +325,7 @@ def _from_class( ) +@functools.lru_cache(maxsize=None) def _maybe_get_delayed( t: Type[VT], *, nullable: bool = False, name: str = None, cls: Type = None ): @@ -338,18 +342,12 @@ def _maybe_get_delayed( return DelayedConstraints( t, nullable=nullable, name=name, factory=get_constraints # type: ignore ) - with guard_recursion(): # pragma: nocover - try: - return get_constraints(t, nullable=nullable, name=name) - except RecursionDetected: - return DelayedConstraints( - t, nullable=nullable, name=name, factory=get_constraints # type: ignore - ) + return get_constraints(t, nullable=nullable, name=name, cls=cls) @functools.lru_cache(maxsize=None) def get_constraints( - t: Type[VT], *, nullable: bool = False, name: str = None + t: Type[VT], *, nullable: bool = False, name: str = None, cls: Type = None ) -> ConstraintsT: while 
should_unwrap(t): nullable = nullable or isoptionaltype(t) @@ -365,5 +363,5 @@ def get_constraints( handler = _from_class else: handler = _CONSTRAINT_BUILDER_HANDLERS.get_by_parent(origin(t), _from_class) # type: ignore - c = handler(t, nullable=nullable, name=name) # type: ignore + c = handler(t, nullable=nullable, name=name, cls=cls) # type: ignore return c diff --git a/typic/constraints/mapping.py b/typic/constraints/mapping.py index 78b5d252..7bb3669a 100644 --- a/typic/constraints/mapping.py +++ b/typic/constraints/mapping.py @@ -329,26 +329,6 @@ def for_schema(self, *, with_type: bool = False) -> dict: propertyNames=( {"pattern": self.key_pattern.pattern} if self.key_pattern else None ), - patternProperties=( - {x: y.for_schema() for x, y in self.patterns.items()} - if self.patterns - else None - ), - additionalProperties=( - self.values.for_schema(with_type=True) - if self.values - else not self.total - ), - dependencies=( - { - x: y.for_schema(with_type=True) - if isinstance(y, BaseConstraints) - else y - for x, y in self.key_dependencies.items() - } - if self.key_dependencies - else None - ), ) if with_type: schema["type"] = "object" diff --git a/typic/ext/schema/field.py b/typic/ext/schema/field.py index 16e6c12e..da94f545 100644 --- a/typic/ext/schema/field.py +++ b/typic/ext/schema/field.py @@ -8,6 +8,7 @@ import ipaddress import pathlib import re +import reprlib import uuid from typing import ( ClassVar, @@ -136,7 +137,7 @@ class BaseSchemaField(_Serializable): writeOnly: Optional[bool] = None extensions: Optional[Tuple[frozendict.FrozenDict[str, Any], ...]] = None - __repr = cached_property(filtered_repr) + __repr = cached_property(reprlib.recursive_repr()(filtered_repr)) def __repr__(self) -> str: # pragma: nocover return self.__repr @@ -325,6 +326,7 @@ class ArraySchemaField(BaseSchemaField): MultiSchemaField, UndeclaredSchemaField, NullSchemaField, + Ref, ] """A type-alias for the defined JSON Schema Fields.""" diff --git 
a/typic/ext/schema/schema.py b/typic/ext/schema/schema.py index 52917451..7736191c 100644 --- a/typic/ext/schema/schema.py +++ b/typic/ext/schema/schema.py @@ -12,17 +12,17 @@ List, Generic, cast, - Set, AnyStr, + TYPE_CHECKING, ) import inflection # type: ignore from typic.common import ReadOnly, WriteOnly from typic.serde.resolver import resolver -from typic.serde.common import SerdeProtocol, Annotation, DelayedSerdeProtocol -from typic.compat import Final, TypedDict -from typic.util import get_args, origin +from typic.serde.common import SerdeProtocol, DelayedSerdeProtocol +from typic.compat import Final, TypedDict, ForwardRef +from typic.util import get_args, origin, get_name from typic.checks import istypeddict, isnamedtuple from typic.types.frozendict import FrozenDict @@ -33,11 +33,13 @@ Ref, SchemaFieldT, SCHEMA_FIELD_FORMATS, - get_field_type, SchemaType, ArraySchemaField, ) +if TYPE_CHECKING: + from typic.constraints import ArrayConstraints, MappingConstraints # noqa: F401 + _IGNORE_DOCS = frozenset({Mapping.__doc__, Generic.__doc__, List.__doc__}) __all__ = ("SchemaBuilder", "SchemaDefinitions", "builder") @@ -73,52 +75,92 @@ def __init__(self): def attach(self, t: Type): self.__attached.add(t) - def _handle_mapping(self, anno: Annotation, constraints: dict, *, name: str = None): + def _handle_mapping( + self, proto: "SerdeProtocol", parent: Type = None, *, name: str = None, **extra + ) -> Mapping: + anno = proto.annotation args = anno.args - constraints["title"] = self.defname(anno.resolved, name=name) + config = extra + config["title"] = self.defname(anno.resolved, name=name) doc = getattr(anno.resolved, "__doc__", None) if doc not in _IGNORE_DOCS: - constraints["description"] = doc - field: Optional[SchemaFieldT] = None + config["description"] = doc + + constraints = cast("MappingConstraints", proto.constraints) + if constraints.items: + config["items"] = { + nm: self.get_field( + resolver.resolve( + it.type, namespace=parent, is_optional=it.nullable 
+ ), + parent=parent, + ) + for nm, it in constraints.items.items() + } + if constraints.patterns: + config["patternProperties"] = { + p: self.get_field( + resolver.resolve( + it.type, namespace=parent, is_optional=it.nullable + ), + parent=parent, + ) + for p, it in constraints.patterns.items() + } + if constraints.key_dependencies: + config["dependencies"] = { + k: it + if isinstance(it, tuple) + else self.get_field( + resolver.resolve( + it.type, namespace=parent, is_optional=it.nullable + ), + parent=parent, + ) + for k, it in constraints.key_dependencies.items() + } + config["additionalProperties"] = not constraints.total if args: - field = self.get_field(resolver.resolve(args[-1])) - if "additionalProperties" in constraints: - other = constraints["additionalProperties"] - # this is coming in from a constraint - if isinstance(other, dict): - schema_type = other.pop("type", None) - field = field or get_field_type(schema_type)() - if isinstance(field, MultiSchemaField): - for k in {"oneOf", "anyOf", "allOf"} & other.keys(): - other[k] = tuple( - get_field_type(x.pop("type"))(**x) for x in other[k] - ) - field = dataclasses.replace(field, **other) - constraints["additionalProperties"] = field - - def _handle_array(self, anno: Annotation, constraints: dict): + config["additionalProperties"] = self.get_field( + resolver.resolve(args[-1], namespace=parent), parent=parent + ) + elif constraints.values: + config["additionalProperties"] = self.get_field( + resolver.resolve( + constraints.values.type, + is_optional=constraints.values.nullable, + namespace=parent, + ), + parent=parent, + ) + return config + + def _handle_array( + self, proto: "SerdeProtocol", parent: Type = None, **extra + ) -> Mapping: + anno = proto.annotation args = anno.args has_ellipsis = args[-1] is Ellipsis if args else False + config = extra if has_ellipsis: args = args[:-1] + constraints = cast("ArrayConstraints", proto.constraints) if args: - constrs = set(self.get_field(resolver.resolve(x)) 
for x in args) - constraints["items"] = (*constrs,) if len(constrs) > 1 else constrs.pop() + constrs = set( + self.get_field(resolver.resolve(x, namespace=parent), parent=parent) + for x in args + ) + config["items"] = (*constrs,) if len(constrs) > 1 else constrs.pop() if anno.origin in {tuple, frozenset}: - constraints["additionalItems"] = False if not has_ellipsis else None + config["additionalItems"] = False if not has_ellipsis else None if anno.origin in {set, frozenset}: - constraints["uniqueItems"] = True - elif "items" in constraints: - items: dict = constraints["items"] - multi_keys: Set[str] = {"oneOf", "anyOf", "allOf"} & items.keys() - if multi_keys: - for k in multi_keys: - items[k] = tuple( - get_field_type(x.pop("type"))(**x) for x in items[k] - ) - constraints["items"] = MultiSchemaField(**items) - else: - constraints["items"] = get_field_type(items.pop("type"))(**items) + config["uniqueItems"] = True + elif constraints.values: + config["items"] = self.get_field( + resolver.resolve(constraints.values.type, namespace=parent), + parent=parent, + ) + return config def get_field( self, @@ -127,6 +169,7 @@ def get_field( ro: bool = None, wo: bool = None, name: str = None, + parent: Type = None, ) -> "SchemaFieldT": """Get a field definition for a JSON Schema.""" anno = protocol.annotation @@ -151,12 +194,19 @@ def get_field( # {'oneOf': [{'type': 'string'}, {'type': 'integer'}]} # We don't care about syntactic sugar if it's functionally the same. 
if use is Union: + fields: List[SchemaFieldT] = [] + args = get_args(anno.un_resolved) + for t in args: + if t.__class__ is ForwardRef or t is parent: + n = name or get_name(t) + fields.append(Ref(f"#/definitions/{n}")) + continue + fields.append( + self.get_field(resolver.resolve(t, namespace=parent), parent=parent) + ) schema = MultiSchemaField( title=self.defname(anno.resolved, name=name) if name else None, - anyOf=tuple( - self.get_field(resolver.resolve(x)) - for x in get_args(anno.un_resolved) - ), + anyOf=(*fields,), ) self.__cache[anno] = schema return schema @@ -191,17 +241,17 @@ def get_field( else: base = cast(SchemaFieldT, SCHEMA_FIELD_FORMATS.get_by_parent(use)) if base: - constraints = ( - protocol.constraints.for_schema() if protocol.constraints else {} - ) - constraints.update(enum=enum_, default=default, readOnly=ro, writeOnly=wo) + config = protocol.constraints.for_schema() if protocol.constraints else {} + config.update(enum=enum_, default=default, readOnly=ro, writeOnly=wo) # `use` should always be a dict if the annotation is a Mapping, # thanks to `origin()` & `resolve()`. 
if isinstance(base, ObjectSchemaField): - self._handle_mapping(anno, constraints, name=name) + config = self._handle_mapping( + protocol, parent=parent, name=name, **config + ) elif isinstance(base, ArraySchemaField): - self._handle_array(anno, constraints) - schema = dataclasses.replace(base, **constraints) + config = self._handle_array(protocol, parent=parent, **config) + schema = dataclasses.replace(base, **config) else: try: schema = self.build_schema(use, name=self.defname(use, name=name)) @@ -285,7 +335,7 @@ def build_schema(self, obj: Type, *, name: str = None) -> "ObjectSchemaField": elif protocol.annotation.resolved_origin is obj: flattened = Ref(f"#/definitions/{self.defname(obj)}") else: - field = self.get_field(protocol, name=nm) + field = self.get_field(protocol, name=nm, parent=obj) # If we received an object schema, # figure out a name and inherit the definitions. flattened = self._flatten_definitions(definitions, field) diff --git a/typic/klass.py b/typic/klass.py index 134b9a81..d4fbdd9f 100644 --- a/typic/klass.py +++ b/typic/klass.py @@ -146,7 +146,7 @@ def make_typedclass( ) if slots: try: - with guard_recursion(): + with guard_recursion(): # pragma: nocover dcls = slotted(dcls) except RecursionDetected: raise TypeError( diff --git a/typic/serde/common.py b/typic/serde/common.py index 65cfe0d2..5a560660 100644 --- a/typic/serde/common.py +++ b/typic/serde/common.py @@ -1,5 +1,6 @@ import dataclasses import inspect +import reprlib import sys import warnings from typing import ( @@ -222,6 +223,19 @@ class ForwardDelayedAnnotation: _name: Optional[str] = None _resolved: Optional["SerdeProtocol"] = dataclasses.field(default=None) + @reprlib.recursive_repr() + def __repr__(self): + return ( + f"{self.__class__}(" + f"ref={self.ref}, " + f"module={self.module!r}, " + f"parameter={self.parameter}, " + f"is_optional={self.is_optional}, " + f"is_strict={self.is_strict}, " + f"flags={self.flags}, " + f"default={self.default})" + ) + @property def 
resolved(self): if self._resolved is None: @@ -230,7 +244,7 @@ def resolved(self): type = evaluate_forwardref(self.ref, globalns or {}, self.localns or {}) except NameError as e: warnings.warn( - f"Counldn't resolve forward reference: {e}. " + f"Couldn't resolve forward reference: {e}. " f"Make sure this type is available in {self.module}." ) type = Any @@ -264,6 +278,17 @@ class DelayedAnnotation: _name: Optional[str] = None _resolved: Optional["SerdeProtocol"] = dataclasses.field(default=None) + @reprlib.recursive_repr() + def __repr__(self): + return ( + f"{self.__class__}(" + f"parameter={self.parameter}, " + f"is_optional={self.is_optional}, " + f"is_strict={self.is_strict}, " + f"flags={self.flags}, " + f"default={self.default})" + ) + @property def resolved(self): if self._resolved is None: diff --git a/typic/serde/des.py b/typic/serde/des.py index 3e1dfa49..296be6b8 100644 --- a/typic/serde/des.py +++ b/typic/serde/des.py @@ -336,11 +336,14 @@ def _build_typeddict_des( annotation: "Annotation", *, total: bool = True, + namespace: Type = None, ): with func.b(f"if issubclass({self.VTYPE}, Mapping):", Mapping=abc.Mapping) as b: fields_deser = { - x: self.resolver._resolve_from_annotation(y).transmute + x: self.resolver._resolve_from_annotation( + y, _namespace=namespace + ).transmute for x, y in annotation.serde.fields.items() } x = "fields_in[x]" @@ -361,11 +364,15 @@ def _build_typeddict_des( ) def _build_typedtuple_des( - self, func: gen.Block, anno_name: str, annotation: "Annotation" + self, + func: gen.Block, + anno_name: str, + annotation: "Annotation", + namespace: Type = None, ): with func.b(f"if issubclass({self.VTYPE}, Mapping):", Mapping=abc.Mapping) as b: if annotation.serde.fields: - self._build_typeddict_des(b, anno_name, annotation) + self._build_typeddict_des(b, anno_name, annotation, namespace=namespace) else: b.l(f"{self.VNAME} = {anno_name}(**{self.VNAME})",) with func.b( @@ -385,15 +392,23 @@ def _build_typedtuple_des( ) def 
_build_mapping_des( - self, func: gen.Block, anno_name: str, annotation: "Annotation", + self, + func: gen.Block, + anno_name: str, + annotation: "Annotation", + namespace: Type = None, ): key_des, item_des = None, None args = annotation.args if args: args = cast(Tuple[Type, Type], args) key_type, item_type = args - key_des = self.resolver.resolve(key_type, flags=annotation.serde.flags) - item_des = self.resolver.resolve(item_type, flags=annotation.serde.flags) + key_des = self.resolver.resolve( + key_type, flags=annotation.serde.flags, namespace=namespace + ) + item_des = self.resolver.resolve( + item_type, flags=annotation.serde.flags, namespace=namespace + ) if issubclass(annotation.resolved_origin, defaultdict): factory = self._get_default_factory(annotation) func.namespace[anno_name] = functools.partial(defaultdict, factory) @@ -431,7 +446,11 @@ def _build_mapping_des( ) def _build_collection_des( - self, func: gen.Block, anno_name: str, annotation: "Annotation" + self, + func: gen.Block, + anno_name: str, + annotation: "Annotation", + namespace: Type = None, ): item_des = None it_name = f"{anno_name}_item_des" @@ -439,7 +458,9 @@ def _build_collection_des( line = f"{self.VNAME} = {anno_name}({iterate})" if annotation.args: item_type = annotation.args[0] - item_des = self.resolver.resolve(item_type, flags=annotation.serde.flags) + item_des = self.resolver.resolve( + item_type, flags=annotation.serde.flags, namespace=namespace + ) line = ( f"{self.VNAME} = " f"{anno_name}({it_name}(x) for x in parent({iterate}))" @@ -463,7 +484,11 @@ def _build_path_des( func.l(f"{self.VNAME} = {anno_name}({self.VNAME})") def _build_generic_des( - self, func: gen.Block, anno_name: str, annotation: "Annotation" + self, + func: gen.Block, + anno_name: str, + annotation: "Annotation", + namespace: Type = None, ): serde = annotation.serde resolved = annotation.resolved @@ -501,7 +526,7 @@ def happypath(k, v, **ns): if serde.fields and len(matched) == len(serde.fields_in): desers = 
{ f: self.resolver._resolve_from_annotation( - serde.fields[f] + serde.fields[f], _namespace=namespace ).transmute for f in matched } @@ -530,7 +555,9 @@ def happypath(k, v, **ns): translate=self.resolver.translate, ) - def _build_des(self, annotation: "Annotation", func_name: str) -> Callable: + def _build_des( + self, annotation: "Annotation", func_name: str, namespace: Type = None + ) -> Callable: args = annotation.args # Get the "origin" of the annotation. # For natives and their typing.* equivs, this will be a builtin type. @@ -564,18 +591,30 @@ def _build_des(self, annotation: "Annotation", func_name: str) -> Callable: self._build_builtin_des(func, anno_name, annotation) elif checks.istypeddict(origin): self._build_typeddict_des( - func, anno_name, annotation, total=origin.__total__ # type: ignore + func, + anno_name, + annotation, + total=origin.__total__, # type: ignore + namespace=namespace, ) elif checks.istypedtuple(origin) or checks.isnamedtuple(origin): - self._build_typedtuple_des(func, anno_name, annotation) + self._build_typedtuple_des( + func, anno_name, annotation, namespace=namespace + ) elif not args and checks.isbuiltinsubtype(origin): self._build_builtin_des(func, anno_name, annotation) elif checks.ismappingtype(origin): - self._build_mapping_des(func, anno_name, annotation) + self._build_mapping_des( + func, anno_name, annotation, namespace=namespace + ) elif checks.iscollectiontype(origin): - self._build_collection_des(func, anno_name, annotation) + self._build_collection_des( + func, anno_name, annotation, namespace=namespace + ) else: - self._build_generic_des(func, anno_name, annotation) + self._build_generic_des( + func, anno_name, annotation, namespace=namespace + ) func.l(f"{gen.Keyword.RET} {self.VNAME}") deserializer = main.compile(ns=ns, name=func_name) return deserializer @@ -645,7 +684,10 @@ def des(val: Any, *, __d=__d, __v=validator) -> ObjectT: return des, validator def factory( - self, annotation: "Annotation", constr: 
Optional["const.ConstraintsT"] = None + self, + annotation: "Annotation", + constr: Optional["const.ConstraintsT"] = None, + namespace: Type = None, ) -> Tuple[DeserializerT, "const.ValidatorT"]: annotation.serde = annotation.serde or SerdeConfig() key = self._get_name(annotation, constr) @@ -657,7 +699,7 @@ def factory( deserializer = des break if not deserializer: - deserializer = self._build_des(annotation, key) + deserializer = self._build_des(annotation, key, namespace) deserializer, validator = self._finalize_deserializer( annotation, deserializer, constr diff --git a/typic/serde/resolver.py b/typic/serde/resolver.py index 8a548501..20f4c94c 100644 --- a/typic/serde/resolver.py +++ b/typic/serde/resolver.py @@ -68,6 +68,7 @@ def __init__(self): self.binder = Binder(self) self.translator = TranslatorFactory(self) self.bind = self.binder.bind + self.__cache = {} for typ in checks.STDLIB_TYPES: self.resolve(typ) self.resolve(Optional[typ]) @@ -502,27 +503,37 @@ def _resolve_from_annotation( anno: Union[Annotation, DelayedAnnotation, ForwardDelayedAnnotation], _des: bool = True, _ser: bool = True, + _namespace: Type = None, ) -> SerdeProtocol: + if anno in self.__cache: + return self.__cache[anno] if isinstance(anno, (DelayedAnnotation, ForwardDelayedAnnotation)): return DelayedSerdeProtocol(anno) + # FIXME: Simulate legacy behavior. 
Should add runtime analysis soon (#95) if anno.origin is Callable: _des, _ser = False, False # Build the deserializer deserializer, validator, constraints = None, None, None if _des: - constraints = constr.get_constraints(anno.resolved, nullable=anno.optional) - deserializer, validator = self.des.factory(anno, constraints) + constraints = constr.get_constraints( + anno.resolved, nullable=anno.optional, cls=_namespace + ) + deserializer, validator = self.des.factory( + anno, constraints, namespace=_namespace + ) # Build the serializer serializer: Optional[SerializerT] = self.ser.factory(anno) if _ser else None # Put it all together - return SerdeProtocol( + proto = SerdeProtocol( annotation=anno, deserializer=deserializer, serializer=serializer, constraints=constraints, validator=validator, ) + self.__cache[anno] = proto + return proto @functools.lru_cache(maxsize=None) def resolve( @@ -582,7 +593,7 @@ def resolve( flags=flags, namespace=namespace, ) - resolved = self._resolve_from_annotation(anno, _des, _ser) + resolved = self._resolve_from_annotation(anno, _des, _ser, namespace) return resolved @functools.lru_cache(maxsize=None) diff --git a/typic/util.py b/typic/util.py index 6da1d9a5..9a91685c 100644 --- a/typic/util.py +++ b/typic/util.py @@ -47,7 +47,6 @@ "filtered_repr", "get_args", "get_name", - "hexhash", "origin", "resolve_supertype", "safe_eval", @@ -82,10 +81,6 @@ } -def hexhash(*args, __order=sys.byteorder, **kwargs) -> str: - return hash(f"{args}{kwargs}").to_bytes(8, __order, signed=True).hex() - - @functools.lru_cache(maxsize=2000, typed=True) def safe_eval(string: str) -> Tuple[bool, Any]: """Try a few methods to evaluate a string and get the correct Python data-type.