Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add type hints #373

Merged
merged 13 commits into from Feb 4, 2019
1 change: 1 addition & 0 deletions .travis.yml
Expand Up @@ -29,6 +29,7 @@ script:
- coverage combine

- make mypy
- make external-mypy
- make docs
- BENCHMARK_REPEATS=1 make benchmark-all
- ./tests/check_tag.py
Expand Down
1 change: 1 addition & 0 deletions HISTORY.rst
Expand Up @@ -10,6 +10,7 @@ v0.19.0 (unreleased)
* Add ``multiple_of`` constraint to ``ConstrainedDecimal``, ``ConstrainedFloat``, ``ConstrainedInt``
and their related types ``condecimal``, ``confloat``, and ``conint`` #371, thanks @StephenBrown2
* Deprecated ``ignore_extra`` and ``allow_extra`` Config fields in favor of ``extra``, #352 by @liiight
* Add type annotations to all functions, test fully with mypy, #373 by @samuelcolvin

v0.18.2 (2019-01-22)
....................
Expand Down
24 changes: 16 additions & 8 deletions Makefile
Expand Up @@ -18,20 +18,26 @@ lint:
pytest pydantic -p no:sugar -q
black -S -l 120 --py36 --check pydantic tests

.PHONY: mypy
mypy:
mypy pydantic

.PHONY: test
test:
pytest --cov=pydantic

.PHONY: mypy
mypy:
.PHONY: external-mypy
external-mypy:
@echo "testing simple example with mypy (and python to check it's sane)..."
mypy --ignore-missing-imports --follow-imports=skip --strict-optional tests/mypy_test_success.py
python tests/mypy_test_success.py
@echo "checking code with bad type annotations fails..."
@mypy --ignore-missing-imports --follow-imports=skip tests/mypy_test_fails.py 1>/dev/null; \
mypy tests/mypy_test_success.py
@echo "checking code with incorrect types fails..."
@mypy tests/mypy_test_fails1.py 1>/dev/null; \
test $$? -eq 1 || \
(echo "mypy_test_fails1: mypy passed when it should have failed!"; exit 1)
@mypy tests/mypy_test_fails2.py 1>/dev/null; \
test $$? -eq 1 || \
(echo "mypy passed when it shouldn't"; exit 1)
python tests/mypy_test_fails.py
(echo "mypy_test_fails2: mypy passed when it should have failed!"; exit 1)

.PHONY: testcov
testcov:
Expand All @@ -40,7 +46,7 @@ testcov:
@coverage html

.PHONY: all
all: testcov mypy lint
all: testcov lint mypy external-mypy

.PHONY: benchmark-all
benchmark-all:
Expand All @@ -57,6 +63,8 @@ clean:
rm -f `find . -type f -name '*~' `
rm -f `find . -type f -name '.*~' `
rm -rf .cache
rm -rf .pytest_cache
rm -rf .mypy_cache
rm -rf htmlcov
rm -rf *.egg-info
rm -f .coverage
Expand Down
8 changes: 8 additions & 0 deletions docs/index.rst
Expand Up @@ -130,6 +130,14 @@ created by the standard library ``dataclass`` decorator.
``pydantic.dataclasses.dataclass``'s arguments are the same as the standard decorator, except one extra
key word argument ``config`` which has the same meaning as :ref:`Config <config>`.

.. note::

As a side effect of getting pydantic dataclasses to play nicely with mypy, the ``config`` argument will show
as invalid in IDEs and mypy; use ``@dataclass(..., config=Config)  # type: ignore`` as a workaround. See
`python/mypy#6239 <https://github.com/python/mypy/issues/6239>`_ for an explanation of why this is.

Nested dataclasses
~~~~~~~~~~~~~~~~~~

Since version ``v0.17`` nested dataclasses are supported both in dataclasses and normal models.

Expand Down
38 changes: 22 additions & 16 deletions pydantic/class_validators.py
Expand Up @@ -3,10 +3,10 @@
from enum import IntEnum
from itertools import chain
from types import FunctionType
from typing import Callable, Dict
from typing import Any, Callable, Dict, List, Optional, Set

from .errors import ConfigError
from .utils import in_ipython
from .utils import AnyCallable, in_ipython


class ValidatorSignature(IntEnum):
Expand All @@ -18,17 +18,19 @@ class ValidatorSignature(IntEnum):

@dataclass
class Validator:
func: Callable
func: AnyCallable
pre: bool
whole: bool
always: bool
check_fields: bool


_FUNCS = set()
_FUNCS: Set[str] = set()


def validator(*fields, pre: bool = False, whole: bool = False, always: bool = False, check_fields: bool = True):
def validator(
*fields: str, pre: bool = False, whole: bool = False, always: bool = False, check_fields: bool = True
) -> Callable[[AnyCallable], classmethod]:
"""
Decorate methods on the class indicating that they should be used to validate fields
:param fields: which field(s) the method should be called on
Expand All @@ -45,7 +47,7 @@ def validator(*fields, pre: bool = False, whole: bool = False, always: bool = Fa
"E.g. usage should be `@validator('<field_name>', ...)`"
)

def dec(f):
def dec(f: AnyCallable) -> classmethod:
# avoid validators with duplicated names since without this validators can be overwritten silently
# which generally isn't the intended behaviour, don't run in ipython - see #312
if not in_ipython(): # pragma: no branch
Expand All @@ -54,26 +56,30 @@ def dec(f):
raise ConfigError(f'duplicate validator function "{ref}"')
_FUNCS.add(ref)
f_cls = classmethod(f)
f_cls.__validator_config = fields, Validator(f, pre, whole, always, check_fields)
f_cls.__validator_config = fields, Validator(f, pre, whole, always, check_fields) # type: ignore
return f_cls

return dec


ValidatorListDict = Dict[str, List[Validator]]


class ValidatorGroup:
def __init__(self, validators):
self.validators: Dict[str, Validator] = validators
def __init__(self, validators: ValidatorListDict) -> None:
self.validators = validators
self.used_validators = {'*'}

def get_validators(self, name):
def get_validators(self, name: str) -> Optional[Dict[str, Validator]]:
    """
    Return the validators that apply to the field ``name``, keyed by validator
    function name, or ``None`` if no validator matches.

    Field-specific validators come before any ``'*'`` wildcard validators in the
    merged mapping. ``name`` is also recorded in ``self.used_validators`` so that
    validators which never matched a field can be reported later
    (see ``check_for_unused``).
    """
    self.used_validators.add(name)
    specific_validators = self.validators.get(name)
    wildcard_validators = self.validators.get('*')
    if specific_validators or wildcard_validators:
        # keyed by function name: a field-specific validator with the same function
        # name as a wildcard one would take the later (wildcard) entry
        validators = (specific_validators or []) + (wildcard_validators or [])
        return {v.func.__name__: v for v in validators}
    return None

def check_for_unused(self):
def check_for_unused(self) -> None:
unused_validators = set(
chain(
*[
Expand All @@ -90,8 +96,8 @@ def check_for_unused(self):
)


def extract_validators(namespace):
validators = {}
def extract_validators(namespace: Dict[str, Any]) -> Dict[str, List[Validator]]:
validators: Dict[str, List[Validator]] = {}
for var_name, value in namespace.items():
validator_config = getattr(value, '__validator_config', None)
if validator_config:
Expand All @@ -104,22 +110,22 @@ def extract_validators(namespace):
return validators


def inherit_validators(base_validators, validators):
def inherit_validators(base_validators: ValidatorListDict, validators: ValidatorListDict) -> ValidatorListDict:
    """
    Merge inherited validators into ``validators`` in place and return it.

    For each field in ``base_validators``, the inherited validators are appended
    after any validators already registered for that field, so the subclass's own
    validators keep their position at the front of the list.
    """
    for field_name, inherited in base_validators.items():
        validators.setdefault(field_name, [])
        validators[field_name] += inherited
    return validators


def get_validator_signature(validator):
def get_validator_signature(validator: Any) -> ValidatorSignature:
signature = inspect.signature(validator)

# bind here will raise a TypeError so:
# 1. we can deal with it before validation begins
# 2. (more importantly) it doesn't get confused with a TypeError when executing the validator
try:
if 'cls' in signature._parameters:
if 'cls' in signature._parameters: # type: ignore
if len(signature.parameters) == 2:
signature.bind(object(), 1)
return ValidatorSignature.CLS_JUST_VALUE
Expand Down
89 changes: 65 additions & 24 deletions pydantic/dataclasses.py
@@ -1,33 +1,50 @@
import dataclasses
from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, Optional, Type, Union

from . import ValidationError, errors
from .main import create_model, validate_model
from .utils import AnyType

if TYPE_CHECKING: # pragma: no cover
from .main import BaseConfig, BaseModel # noqa: F401

def _pydantic_post_init(self):
class DataclassType:
    """
    Typing-only stand-in describing the attributes and hooks that
    ``_process_class`` attaches to a pydantic dataclass; it gives mypy something
    to check annotations such as ``-> 'DataclassType'`` against.
    """

    # model generated from the dataclass fields (set by _process_class)
    __pydantic_model__: Type[BaseModel]
    # the user's original __post_init__, if any (set by _process_class)
    __post_init_original__: Callable[..., None]
    # set to True by _pydantic_post_init once validation has populated __dict__
    __initialised__: bool

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        pass

    @classmethod
    def __validate__(cls, v: Any) -> 'DataclassType':
        # at runtime this is _validate_dataclass, bound via classmethod in _process_class
        pass


def _pydantic_post_init(self: 'DataclassType') -> None:
    """
    Replacement ``__post_init__`` installed on pydantic dataclasses: validates
    the instance's fields through the generated pydantic model, then runs the
    user's original ``__post_init__`` if one was defined.
    """
    d = validate_model(self.__pydantic_model__, self.__dict__)
    # object.__setattr__ bypasses any custom/frozen __setattr__ on the dataclass
    object.__setattr__(self, '__dict__', d)
    # flag read by setattr_validate_assignment to know initialisation is complete
    object.__setattr__(self, '__initialised__', True)
    if self.__post_init_original__:
        self.__post_init_original__()


def _validate_dataclass(cls, v):
def _validate_dataclass(cls: Type['DataclassType'], v: Any) -> 'DataclassType':
if isinstance(v, cls):
return v
elif isinstance(v, (cls, list, tuple)):
elif isinstance(v, (list, tuple)):
return cls(*v)
elif isinstance(v, dict):
return cls(**v)
else:
raise errors.DataclassTypeError(class_name=cls.__name__)


def _get_validators(cls):
def _get_validators(cls: Type['DataclassType']) -> Generator[Any, None, None]:
    """Yield the class's ``__validate__`` hook."""
    # NOTE(review): presumably consumed by pydantic's validator-discovery
    # protocol so the dataclass can be used as a model field type — confirm in main.py
    yield cls.__validate__


def setattr_validate_assignment(self, name, value):
def setattr_validate_assignment(self: 'DataclassType', name: str, value: Any) -> None:
if self.__initialised__:
d = dict(self.__dict__)
d.pop(name)
Expand All @@ -38,17 +55,26 @@ def setattr_validate_assignment(self, name, value):
object.__setattr__(self, name, value)


def _process_class(_cls, init, repr, eq, order, unsafe_hash, frozen, config):
def _process_class(
_cls: AnyType,
init: bool,
repr: bool,
eq: bool,
order: bool,
unsafe_hash: bool,
frozen: bool,
config: Type['BaseConfig'],
) -> 'DataclassType':
post_init_original = getattr(_cls, '__post_init__', None)
if post_init_original and post_init_original.__name__ == '_pydantic_post_init':
post_init_original = None
_cls.__post_init__ = _pydantic_post_init
cls = dataclasses._process_class(_cls, init, repr, eq, order, unsafe_hash, frozen)
cls = dataclasses._process_class(_cls, init, repr, eq, order, unsafe_hash, frozen) # type: ignore

fields = {name: (field.type, field.default) for name, field in cls.__dataclass_fields__.items()}
fields: Dict[str, Any] = {name: (field.type, field.default) for name, field in cls.__dataclass_fields__.items()}
cls.__post_init_original__ = post_init_original

cls.__pydantic_model__ = create_model(cls.__name__, __config__=config, **fields)
cls.__pydantic_model__ = create_model(cls.__name__, __config__=config, __base__=None, **fields)

cls.__initialised__ = False
cls.__validate__ = classmethod(_validate_dataclass)
Expand All @@ -60,18 +86,33 @@ def _process_class(_cls, init, repr, eq, order, unsafe_hash, frozen, config):
return cls


def dataclass(_cls=None, *, init=True, repr=True, eq=True, order=False, unsafe_hash=False, frozen=False, config=None):
"""
Like the python standard lib dataclasses but with type validation.

Arguments are the same as for standard dataclasses, except for validate_assignment which has the same meaning
as Config.validate_assignment.
"""

def wrap(cls):
return _process_class(cls, init, repr, eq, order, unsafe_hash, frozen, config)

if _cls is None:
return wrap

return wrap(_cls)
if TYPE_CHECKING:  # pragma: no cover
    # Re-export the stdlib decorator for type checkers only: mypy special-cases
    # dataclasses.dataclass, and our extra `config` kwarg confuses it.
    # see https://github.com/python/mypy/issues/6239 for explanation of why we do this
    from dataclasses import dataclass
else:

    def dataclass(
        _cls: Optional[AnyType] = None,
        *,
        init: bool = True,
        repr: bool = True,
        eq: bool = True,
        order: bool = False,
        unsafe_hash: bool = False,
        frozen: bool = False,
        config: Optional[Type['BaseConfig']] = None,
    ) -> Union[Callable[[AnyType], 'DataclassType'], 'DataclassType']:
        """
        Like the python standard lib dataclasses but with type validation.

        Arguments are the same as for standard dataclasses, except for one extra keyword
        argument ``config`` which has the same meaning as ``Config`` on pydantic models.
        """

        def wrap(cls: AnyType) -> 'DataclassType':
            return _process_class(cls, init, repr, eq, order, unsafe_hash, frozen, config)

        # support both bare `@dataclass` (called with the class, _cls is set) and
        # `@dataclass(...)` (called with arguments, return the decorator itself)
        if _cls is None:
            return wrap

        return wrap(_cls)