
Commit

more typing
samuelcolvin committed Jan 27, 2019
1 parent c5d0487 commit 77eba2b
Showing 4 changed files with 155 additions and 95 deletions.
52 changes: 42 additions & 10 deletions pydantic/dataclasses.py
@@ -1,34 +1,49 @@
import dataclasses
from typing import TYPE_CHECKING, Any, Dict
from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, Optional, Type, Union

from . import ValidationError, errors
from .main import create_model, validate_model

if TYPE_CHECKING: # pragma: no cover
from .main import BaseConfig, BaseModel # noqa

class DataclassType:
__pydantic_model__: Type[BaseModel]
__post_init_original__: Callable[..., None]
__initialised__: bool

def __init__(self, *args: Any, **kwargs: Any) -> None:
pass

@classmethod
def __validate__(cls, v: Any) -> 'DataclassType':
pass


def _pydantic_post_init(self):
def _pydantic_post_init(self: 'DataclassType') -> None:
d = validate_model(self.__pydantic_model__, self.__dict__)
object.__setattr__(self, '__dict__', d)
object.__setattr__(self, '__initialised__', True)
if self.__post_init_original__:
self.__post_init_original__()


def _validate_dataclass(cls, v):
def _validate_dataclass(cls: Type['DataclassType'], v: Any) -> 'DataclassType':
if isinstance(v, cls):
return v
elif isinstance(v, (cls, list, tuple)):
elif isinstance(v, (list, tuple)):
return cls(*v)
elif isinstance(v, dict):
return cls(**v)
else:
raise errors.DataclassTypeError(class_name=cls.__name__)


def _get_validators(cls):
def _get_validators(cls: Type['DataclassType']) -> Generator[Any, None, None]:
yield cls.__validate__


def setattr_validate_assignment(self, name, value):
def setattr_validate_assignment(self: 'DataclassType', name: str, value: Any) -> None:
if self.__initialised__:
d = dict(self.__dict__)
d.pop(name)
@@ -39,7 +54,16 @@ def setattr_validate_assignment(self, name, value):
object.__setattr__(self, name, value)


def _process_class(_cls, init, repr, eq, order, unsafe_hash, frozen, config):
def _process_class(
_cls: Type[Any],
init: bool,
repr: bool,
eq: bool,
order: bool,
unsafe_hash: bool,
frozen: bool,
config: Type['BaseConfig'],
) -> 'DataclassType':
post_init_original = getattr(_cls, '__post_init__', None)
if post_init_original and post_init_original.__name__ == '_pydantic_post_init':
post_init_original = None
@@ -67,16 +91,24 @@ def _process_class(_cls, init, repr, eq, order, unsafe_hash, frozen, config):
else:

def dataclass(
_cls=None, *, init=True, repr=True, eq=True, order=False, unsafe_hash=False, frozen=False, config=None
):
_cls: Optional[Type[Any]] = None,
*,
init: bool = True,
repr: bool = True,
eq: bool = True,
order: bool = False,
unsafe_hash: bool = False,
frozen: bool = False,
config: Type['BaseConfig'] = None,
) -> Union[Callable[[Type[Any]], 'DataclassType'], 'DataclassType']:
"""
Like the python standard lib dataclasses but with type validation.
Arguments are the same as for standard dataclasses, except for validate_assignment which has the same meaning
as Config.validate_assignment.
"""

def wrap(cls):
def wrap(cls: Type[Any]) -> 'DataclassType':
return _process_class(cls, init, repr, eq, order, unsafe_hash, frozen, config)

if _cls is None:
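For context, a minimal usage sketch of the pydantic dataclass decorator whose signature is annotated above. This example is not part of the commit; the User class and its fields are invented for illustration, and it assumes pydantic at roughly this revision.

from pydantic import ValidationError
from pydantic.dataclasses import dataclass


@dataclass
class User:
    id: int
    name: str = 'Jane Doe'


user = User(id='42')   # '42' is coerced to int by the validating __post_init__
assert user.id == 42

try:
    User(id='not an int')
except ValidationError as exc:
    print(exc)   # reports the failing field, as produced by validate_model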
111 changes: 64 additions & 47 deletions pydantic/main.py
@@ -6,7 +6,22 @@
from functools import partial
from pathlib import Path
from types import FunctionType
from typing import Any, Callable, ClassVar, Dict, List, Optional, Set, Tuple, Type, Union, cast, no_type_check
from typing import (
TYPE_CHECKING,
Any,
Callable,
ClassVar,
Dict,
Generator,
List,
Optional,
Set,
Tuple,
Type,
Union,
cast,
no_type_check,
)

from .class_validators import Validator, ValidatorGroup, extract_validators, inherit_validators
from .error_wrappers import ErrorWrapper, ValidationError
@@ -38,7 +53,7 @@ class BaseConfig:
json_encoders: Dict[Type[Any], Callable[..., Any]] = {}

@classmethod
def get_field_schema(cls, name):
def get_field_schema(cls, name: str) -> Dict[str, str]:
field_config = cls.fields.get(name) or {}
if isinstance(field_config, str):
field_config = {'alias': field_config}
@@ -132,16 +147,20 @@ def __new__(mcs, name, bases, namespace):


class BaseModel(metaclass=MetaModel):
# populated by the metaclass, defined here to help IDEs only
__fields__: Dict[str, Field] = {}
__validators__: Dict[str, Callable[..., Any]] = {}
__config__: BaseConfig = BaseConfig()
_json_encoder: Callable[[Any], Any] = lambda x: x # pragma: no branch
if TYPE_CHECKING: # pragma: no cover
# populated by the metaclass, defined here to help IDEs only
__fields__: Dict[str, Field] = {}
__validators__: Dict[str, Callable[..., Any]] = {}
__config__: BaseConfig = BaseConfig()
_json_encoder: Callable[[Any], Any] = lambda x: x
_schema_cache: Dict[Any, Any] = {}

Config = BaseConfig
__slots__ = ('__values__',)

def __init__(self, **data):
def __init__(self, **data: Any) -> None:
if TYPE_CHECKING:
self.__values__: Dict[Any, Any] = {}
self.__setstate__(self._process_values(data))

@no_type_check
@@ -166,10 +185,10 @@ def __setattr__(self, name, value):
else:
self.__values__[name] = value

def __getstate__(self):
def __getstate__(self) -> Dict[Any, Any]:
return self.__values__

def __setstate__(self, state):
def __setstate__(self, state: Dict[Any, Any]) -> None:
object.__setattr__(self, '__values__', state)

def dict(self, *, include: Set[str] = None, exclude: Set[str] = None, by_alias: bool = False) -> Dict[str, Any]:
@@ -199,7 +218,7 @@ def json(
exclude: Set[str] = None,
by_alias: bool = False,
encoder: Optional[Callable[[Any], Any]] = None,
**dumps_kwargs,
**dumps_kwargs: Any,
) -> str:
"""
Generate a JSON representation of the model, `include` and `exclude` arguments as per `dict()`.
@@ -212,7 +231,7 @@ def json(
)

@classmethod
def parse_obj(cls, obj):
def parse_obj(cls, obj: Dict[Any, Any]) -> 'BaseModel':
if not isinstance(obj, dict):
exc = TypeError(f'{cls.__name__} expected dict not {type(obj).__name__}')
raise ValidationError([ErrorWrapper(exc, loc='__obj__')])
@@ -227,7 +246,7 @@ def parse_raw(
encoding: str = 'utf8',
proto: Protocol = None,
allow_pickle: bool = False,
):
) -> 'BaseModel':
try:
obj = load_str_bytes(
b, proto=proto, content_type=content_type, encoding=encoding, allow_pickle=allow_pickle
@@ -245,12 +264,12 @@ def parse_file(
encoding: str = 'utf8',
proto: Protocol = None,
allow_pickle: bool = False,
):
) -> 'BaseModel':
obj = load_file(path, proto=proto, content_type=content_type, encoding=encoding, allow_pickle=allow_pickle)
return cls.parse_obj(obj)

@classmethod
def construct(cls, **values):
def construct(cls, **values: Any) -> 'BaseModel':
"""
Creates a new model and set __values__ without any validation, thus values should already be trusted.
Chances are you don't want to use this method directly.
@@ -261,7 +280,7 @@ def construct(cls, **values):

def copy(
self, *, include: Set[str] = None, exclude: Set[str] = None, update: Dict[str, Any] = None, deep: bool = False
):
) -> 'BaseModel':
"""
Duplicate a model, optionally choose which fields to include, exclude and change.
@@ -274,10 +293,10 @@ def copy(
"""
if include is None and exclude is None and update is None:
# skip constructing values if no arguments are passed
v = self.__values__ # type: ignore
v = self.__values__
else:
exclude = exclude or set()
v_ = self.__values__ # type: ignore
v_ = self.__values__
v = {
**{k: v for k, v in v_.items() if k not in exclude and (not include or k in include)},
**(update or {}),
@@ -287,38 +306,38 @@ def copy(
return self.__class__.construct(**v)

@property
def fields(self):
def fields(self) -> Dict[str, Field]:
return self.__fields__

@classmethod
def schema(cls, by_alias=True) -> Dict[str, Any]:
cached = cls._schema_cache.get(by_alias) # type: ignore
def schema(cls, by_alias: bool = True) -> Dict[str, Any]:
cached = cls._schema_cache.get(by_alias)
if cached is not None:
return cached
s = model_schema(cls, by_alias=by_alias)
cls._schema_cache[by_alias] = s # type: ignore
cls._schema_cache[by_alias] = s
return s

@classmethod
def schema_json(cls, *, by_alias=True, **dumps_kwargs) -> str:
def schema_json(cls, *, by_alias: bool = True, **dumps_kwargs: Any) -> str:
from .json import pydantic_encoder

return json.dumps(cls.schema(by_alias=by_alias), default=pydantic_encoder, **dumps_kwargs)

@classmethod
def __get_validators__(cls):
def __get_validators__(cls) -> Generator[Any, None, None]:
yield dict_validator
yield cls.validate

@classmethod
def validate(cls, value):
def validate(cls, value: Dict[str, Any]) -> 'BaseModel':
return cls(**value)

def _process_values(self, input_data: Any) -> Dict[str, Any]:
return validate_model(self, input_data)
return cast(Dict[str, Any], validate_model(self, input_data))

@classmethod
def _get_value(cls, v, by_alias=False):
def _get_value(cls, v: Any, by_alias: bool) -> Any:
if isinstance(v, BaseModel):
return v.dict(by_alias=by_alias)
elif isinstance(v, list):
@@ -333,55 +352,51 @@ def _get_value(cls, v, by_alias=False):
return v

@classmethod
def update_forward_refs(cls, **localns):
def update_forward_refs(cls, **localns: Any) -> None:
"""
Try to update ForwardRefs on fields based on this Model, globalns and localns.
"""
globalns = sys.modules[cls.__module__].__dict__
globalns.setdefault(cls.__name__, cls)
for f in cls.__fields__.values():
if type(f.type_) == ForwardRef:
f.type_ = f.type_._evaluate(globalns, localns or None)
f.type_ = f.type_._evaluate(globalns, localns or None) # type: ignore
f.prepare()

def __iter__(self):
def __iter__(self) -> Generator[Any, None, None]:
"""
so `dict(model)` works
"""
yield from self._iter()

def _iter(self, by_alias=False):
def _iter(self, by_alias: bool = False) -> Generator[Tuple[str, Any], None, None]:
for k, v in self.__values__.items():
yield k, self._get_value(v, by_alias=by_alias)

def __eq__(self, other):
def __eq__(self, other: Any) -> bool:
if isinstance(other, BaseModel):
return self.dict() == other.dict()
else:
return self.dict() == other

def __repr__(self):
def __repr__(self) -> str:
return f'<{self}>'

def to_string(self, pretty=False):
def to_string(self, pretty: bool = False) -> str:
divider = '\n ' if pretty else ' '
return '{}{}{}'.format(
self.__class__.__name__,
divider,
divider.join('{}={}'.format(k, truncate(v)) for k, v in self.__values__.items()),
)

def __str__(self):
def __str__(self) -> str:
return self.to_string()


def create_model(
model_name: str,
*,
__config__: Type[BaseConfig] = None,
__base__: Type[BaseModel] = None,
**field_definitions: Dict[str, Any],
):
model_name: str, *, __config__: Type[BaseConfig] = None, __base__: Type[BaseModel] = None, **field_definitions: Any
) -> BaseModel:
"""
Dynamically create a model.
:param model_name: name of the created model
@@ -423,17 +438,19 @@ def create_model(
if __config__:
namespace['Config'] = inherit_config(__config__, BaseConfig)

return type(model_name, (__base__,), namespace)
return cast(BaseModel, type(model_name, (__base__,), namespace))


def validate_model(model: BaseModel, input_data: Dict[str, Any], raise_exc=True): # noqa: C901 (ignore complexity)
def validate_model( # noqa: C901 (ignore complexity)
model: Union[BaseModel, Type[BaseModel]], input_data: Dict[str, Any], raise_exc: bool = True
) -> Union[Dict[str, Any], Tuple[Dict[str, Any], Optional[ValidationError]]]:
"""
validate data against a model.
"""
values = {}
errors = []
names_used = set()
check_extra = (not model.__config__.ignore_extra) or model.__config__.allow_extra
values: Dict[str, Any] = {}
errors: List[ErrorWrapper] = []
names_used: Set[str] = set()
check_extra: bool = (not model.__config__.ignore_extra) or model.__config__.allow_extra

for name, field in model.__fields__.items():
if type(field.type_) == ForwardRef:
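As a point of reference, a short hypothetical sketch of the BaseModel and create_model APIs whose annotations change above. It is not part of the commit; the class and field names are invented, and it assumes pydantic at roughly this revision.

from pydantic import BaseModel, create_model


class Model(BaseModel):
    a: int
    b: str = 'x'


m = Model.parse_obj({'a': '1'})   # the str '1' is coerced to int by validate_model
print(m.dict())                   # {'a': 1, 'b': 'x'}
print(m.json())                   # '{"a": 1, "b": "x"}'

# create_model builds an equivalent class dynamically; each field definition is
# either a (type, default) tuple or a bare default value.
Dynamic = create_model('Dynamic', a=(int, ...), b='x')
print(Dynamic(a=2).dict())        # {'a': 2, 'b': 'x'}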
4 changes: 2 additions & 2 deletions pydantic/schema.py
@@ -110,7 +110,7 @@ def __init__(
min_length: int = None,
max_length: int = None,
regex: str = None,
**extra: Dict[str, Any],
**extra: Any,
) -> None:
self.default = default
self.alias = alias
@@ -680,7 +680,7 @@ def encode_default(dft: Any) -> Any:
_map_types_constraint: Dict[Any, Callable[..., type]] = {int: conint, float: confloat, Decimal: condecimal}


def get_annotation_from_schema(annotation: Any, schema: Schema) -> Any:
def get_annotation_from_schema(annotation: Any, schema: Schema) -> Type[Any]:
"""
Get an annotation with validation implemented for numbers and strings based on the schema.
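Finally, a hypothetical sketch of how the Schema class (defined in pydantic/schema.py and touched above) relates to get_annotation_from_schema. This is illustrative only, not part of the commit, and assumes pydantic at roughly this revision.

from pydantic import BaseModel
from pydantic.schema import Schema


class Item(BaseModel):
    # Schema attaches constraints and JSON-schema metadata to a field; a
    # constraint such as max_length is what get_annotation_from_schema can
    # turn into a constrained type annotation.
    name: str = Schema(..., max_length=50)


print(Item.schema())   # JSON-schema dict, cached per by_alias flag as shown in main.py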