Skip to content

Commit

Permalink
Start enumerating tests for constraints factory.
Browse files Browse the repository at this point in the history
Also, add support for unbound forward references.
  • Loading branch information
seandstewart committed Apr 5, 2023
1 parent b4394ef commit 88d218d
Show file tree
Hide file tree
Showing 10 changed files with 281 additions and 114 deletions.
40 changes: 12 additions & 28 deletions .coveragerc
Original file line number Diff line number Diff line change
Expand Up @@ -3,44 +3,28 @@ branch = True
data_file = coverage.db
include =
src/*
omit =
*__init__.py
tests/*

[paths]
source =
src/
*/src/*

[report]
# Regexes for lines to exclude from consideration
exclude_lines =
# Have to re-enable the standard pragma
pragma: nocover
pragma: nobranch
pragma: no cover
pragma: no branch

# Don't complain about missing debug-only code:
sort = Cover
exclude_also =
def __repr__
if self\.debug
\.\.\.

# Don't complain if tests don't hit defensive assertion code:
def __str__
if self.debug:
if settings.DEBUG
raise AssertionError
raise NotImplementedError

# Don't complain if non-runnable code isn't run:
if 0:
if __name__ == .__main__.:
if TYPE_CHECKING:
class.*\(Protocol.*\):
\@abc\.abstractmethod

omit =
setup.py
.*env*
*lib/python*
dist*
tests*
benchmark*
docs*
mypy.py
class\s\w+\((typing\.)?Protocol(\[.*\])?\):
@(abc\.)?abstractmethod
@(typing\.)?overload

skip_empty = True
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -111,3 +111,4 @@ venv.bak/
/pip-wheel-metadata/
/benchmark/.cases.json
*.DS_Store
coverage.db
103 changes: 52 additions & 51 deletions poetry.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion src/typical/classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def wrap(cls):

_stack.add(key)

if sys.version_info >= (3, 10) and "typical" not in cls.__qualname__:
if sys.version_info >= (3, 10) and "typical" not in cls.__module__:
warnings.warn(
f"You are using Python {sys.version}. "
"Python 3.10 introduced native support for slotted dataclasses. "
Expand Down
3 changes: 1 addition & 2 deletions src/typical/core/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@


class empty:
def __bool__(self):
return False
"""A singleton for signalling no input."""


DEFAULT_ENCODING = "utf-8"
Expand Down
10 changes: 4 additions & 6 deletions src/typical/core/constraints/core/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import numbers
import re
import reprlib
import sys
import warnings
from typing import (
Any,
Expand Down Expand Up @@ -344,7 +343,7 @@ def error(
class DelayedConstraintValidator(AbstractConstraintValidator[_VT]):
__slots__ = (
"ref",
"module",
"globalns",
"localns",
"nullable",
"readonly",
Expand All @@ -357,7 +356,7 @@ class DelayedConstraintValidator(AbstractConstraintValidator[_VT]):
def __init__(
self,
ref: ForwardRef | type,
module: str,
globalns: Mapping,
localns: Mapping,
nullable: bool,
readonly: bool,
Expand All @@ -367,7 +366,7 @@ def __init__(
**config,
):
self.ref = ref
self.module = module
self.globalns = globalns
self.localns = localns
self.nullable = nullable
self.readonly = readonly
Expand All @@ -381,9 +380,8 @@ def __init__(
def _evaluate_reference(self) -> AbstractConstraintValidator[_VT]:
type = self.ref
if isinstance(self.ref, ForwardRef):
globalns = sys.modules[self.module].__dict__.copy()
try:
type = evaluate_forwardref(self.ref, globalns or {}, self.localns or {})
type = evaluate_forwardref(self.ref, self.globalns, self.localns)
except NameError as e: # pragma: nocover
warnings.warn(
f"Counldn't resolve forward reference: {e}. "
Expand Down
51 changes: 25 additions & 26 deletions src/typical/core/constraints/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import decimal as stdlib_decimal
import functools
import inspect
import typing
from typing import Any, Callable, Collection, Hashable, TypeVar, Union, cast

from typical import checks, inspection
Expand Down Expand Up @@ -56,6 +57,16 @@ def build(
default: Hashable | Callable[[], VT] | constants.empty = constants.empty,
**config,
) -> types.AbstractConstraintValidator:
if t in self.NOOP:
return self._from_undeclared_type(
t=t,
nullable=nullable,
readonly=readonly,
writeonly=writeonly,
cls=cls,
default=default,
)

if hasattr(t, "__constraints__"):
return t.__constraints__ # type: ignore[attr-defined]

Expand All @@ -73,27 +84,6 @@ def build(

t = args[0] if len(args) == 1 else Union[args]

if t in (Any, ..., type(...)):
return engine.ConstraintValidator(
constraints=types.TypeConstraints(
type=t,
nullable=nullable,
readonly=readonly,
writeonly=writeonly,
default=default,
),
validator=validators.NoOpInstanceValidator(
type=t,
precheck=validators.NoOpPrecheck(
type=t,
nullable=nullable,
readonly=readonly,
writeonly=writeonly,
name=name,
**config,
),
),
)
if t is cls or t in self.__visited:
module = getattr(t, "__module__", None)
if cls and cls is not ...:
Expand All @@ -116,14 +106,23 @@ def build(
t = ForwardRef(str(t)) # type: ignore[assignment]

if checks.isforwardref(t):
# If we don't have an enclosing scope, search for one.
if not cls or cls is ...:
raise TypeError(
f"Cannot build constraints for {t} without an enclosing class."
)
caller = inspection.getcaller()
globalns, localns = caller.f_globals, caller.f_locals
# Otherwise, use the context from the enclosing scope.
else:
globalns = {}
module = inspect.getmodule(cls)
if module:
globalns = vars(module)
localns = dict(inspect.getmembers(cls))

globalns.update(typing=typing)
return types.DelayedConstraintValidator(
ref=t,
module=cls.__module__,
localns=(getattr(cls, "__dict__", None) or {}).copy(),
globalns=globalns,
localns=localns,
nullable=nullable,
readonly=readonly,
writeonly=writeonly,
Expand Down
24 changes: 24 additions & 0 deletions src/typical/inspection.py
Original file line number Diff line number Diff line change
Expand Up @@ -575,6 +575,30 @@ def extract(name: str, *, frame: types.FrameType = None) -> Optional[Any]:
return None


def getcaller(frame: types.FrameType = None) -> types.FrameType:
"""Get the caller of the current scope, excluding this library.
If `frame` is not provided, this function will use the current frame.
"""
if frame is None:
frame = inspect.currentframe()

while frame.f_back:
frame = frame.f_back
module = inspect.getmodule(frame)
if module and module.__name__.startswith("typical"):
continue

code = frame.f_code
if getattr(code, "co_qualname", "").startswith("typical"):
continue
if "typical" in code.co_filename:
continue
return frame

return frame


@lru_cache(maxsize=None)
def get_type_graph(t: Type) -> TypeGraph:
"""Get a directed graph of the type(s) this annotation represents."""
Expand Down
147 changes: 147 additions & 0 deletions tests/unit/core/constraints/test_factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
from __future__ import annotations

import decimal
import enum
import inspect
import typing

import pytest

from typical.core import constants
from typical.core.constraints import factory
from typical.core.constraints.core import types, validators


class MyEnum(enum.Enum):
...


class MyClass:
...


@pytest.mark.suite(
anytype=dict(
given_type=typing.Any,
given_context=dict(),
expected_constraints_cls=types.UndeclaredTypeConstraints,
expected_validator_cls=validators.NoOpInstanceValidator,
),
constants_empty=dict(
given_type=constants.empty,
given_context=dict(),
expected_constraints_cls=types.UndeclaredTypeConstraints,
expected_validator_cls=validators.NoOpInstanceValidator,
),
param_empty=dict(
given_type=inspect.Parameter.empty,
given_context=dict(),
expected_constraints_cls=types.UndeclaredTypeConstraints,
expected_validator_cls=validators.NoOpInstanceValidator,
),
ellipsis=dict(
given_type=Ellipsis,
given_context=dict(),
expected_constraints_cls=types.UndeclaredTypeConstraints,
expected_validator_cls=validators.NoOpInstanceValidator,
),
enum=dict(
given_type=MyEnum,
given_context=dict(),
expected_constraints_cls=types.EnumerationConstraints,
expected_validator_cls=validators.OneOfValidator,
),
literal=dict(
given_type=typing.Literal[1, 2],
given_context=dict(),
expected_constraints_cls=types.EnumerationConstraints,
expected_validator_cls=validators.OneOfValidator,
),
string=dict(
given_type=str,
given_context=dict(),
expected_constraints_cls=types.TextConstraints,
expected_validator_cls=validators.IsInstanceValidator,
),
bytestring=dict(
given_type=bytes,
given_context=dict(),
expected_constraints_cls=types.TextConstraints,
expected_validator_cls=validators.IsInstanceValidator,
),
boolean=dict(
given_type=bool,
given_context=dict(),
expected_constraints_cls=types.TypeConstraints,
expected_validator_cls=validators.IsInstanceValidator,
),
integer=dict(
given_type=int,
given_context=dict(),
expected_constraints_cls=types.NumberConstraints,
expected_validator_cls=validators.IsInstanceValidator,
),
float=dict(
given_type=float,
given_context=dict(),
expected_constraints_cls=types.NumberConstraints,
expected_validator_cls=validators.IsInstanceValidator,
),
decimal=dict(
given_type=decimal.Decimal,
given_context=dict(),
expected_constraints_cls=types.DecimalConstraints,
expected_validator_cls=validators.IsInstanceValidator,
),
structured=dict(
given_type=MyClass,
given_context=dict(),
expected_constraints_cls=types.StructuredObjectConstraints,
expected_validator_cls=validators.IsInstanceValidator,
),
dict=dict(
given_type=dict,
given_context=dict(),
expected_constraints_cls=types.MappingConstraints,
expected_validator_cls=validators.IsInstanceValidator,
),
mapping=dict(
given_type=typing.Mapping,
given_context=dict(),
expected_constraints_cls=types.MappingConstraints,
expected_validator_cls=validators.IsInstanceValidator,
),
collection=dict(
given_type=typing.Collection,
given_context=dict(),
expected_constraints_cls=types.ArrayConstraints,
expected_validator_cls=validators.IsInstanceValidator,
),
optional=dict(
given_type=typing.Optional[str],
given_context=dict(),
expected_constraints_cls=types.TextConstraints,
expected_validator_cls=validators.NullableIsInstanceValidator,
),
)
def test_build(
given_type, given_context, expected_constraints_cls, expected_validator_cls
):
# When
built_cv = factory.build(t=given_type, **given_context)
# Then
assert isinstance(built_cv.constraints, expected_constraints_cls)
assert isinstance(built_cv.validator, expected_validator_cls)


def test_build_forwardref():
# Given
given_type = "dict | None"
expected_constraints_cls = types.MappingConstraints
expected_validator_cls = validators.NullableIsInstanceValidator
# When
built_dcv = factory.build(t=given_type)
built_cv = built_dcv.cv
# Then
assert isinstance(built_cv.constraints, expected_constraints_cls)
assert isinstance(built_cv.validator, expected_validator_cls)

0 comments on commit 88d218d

Please sign in to comment.