Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for Literal annotation #582

Merged
merged 11 commits into from
Jun 25, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
10 changes: 2 additions & 8 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ install:
- pip freeze

script:
# test without cython but with ujson and email-validator
- python -c "import sys, pydantic; print('compiled:', pydantic.compiled); sys.exit(1 if pydantic.compiled else 0)"
- make test

Expand All @@ -40,7 +39,6 @@ jobs:
python: 3.6
name: 'Cython: 3.6'
script:
# test with cython, ujson and email-validator
- make build-cython-trace
- python -c "import sys, pydantic; print('compiled:', pydantic.compiled); sys.exit(0 if pydantic.compiled else 1)"
- make test
Expand All @@ -50,7 +48,6 @@ jobs:
python: 3.7
name: 'Cython: 3.7'
script:
# test with cython, ujson and email-validator
- make build-cython-trace
- python -c "import sys, pydantic; print('compiled:', pydantic.compiled); sys.exit(0 if pydantic.compiled else 1)"
- make test
Expand All @@ -61,17 +58,15 @@ jobs:
python: 3.6
name: 'Without Deps 3.6'
script:
# test without cython, ujson and email-validator
- pip uninstall -y ujson email-validator
- pip uninstall -y ujson email-validator typing-extensions
- make test
env:
- 'DEPS=no'
- stage: test
python: 3.7
name: 'Without Deps 3.7'
script:
# test without cython, ujson and email-validator
- pip uninstall -y ujson email-validator cython
- pip uninstall -y ujson email-validator cython typing-extensions
- make test
env:
- 'DEPS=no'
Expand All @@ -80,7 +75,6 @@ jobs:
python: 3.7
name: 'Benchmarks'
script:
# default install skips cython compilation, need to compile for benchmarks
- make build-cython
- BENCHMARK_REPEATS=1 make benchmark-all
after_success: skip
Expand Down
5 changes: 4 additions & 1 deletion pydantic/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,10 @@ class NoneIsAllowedError(PydanticTypeError):

class WrongConstantError(PydanticValueError):
code = 'const'
msg_template = 'expected constant value {const!r}'

def __str__(self) -> str:
permitted = ', '.join(repr(v) for v in self.ctx['permitted']) # type: ignore
return f'unexpected value; permitted: {permitted}'


class BytesError(PydanticTypeError):
Expand Down
7 changes: 7 additions & 0 deletions pydantic/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,11 @@
from .utils import AnyCallable, AnyType, Callable, ForwardRef, display_as_type, lenient_issubclass, sequence_like
from .validators import NoneType, constant_validator, dict_validator, find_validators

try:
from typing_extensions import Literal
dmontagu marked this conversation as resolved.
Show resolved Hide resolved
except ImportError:
Literal = None # type: ignore

Required: Any = Ellipsis

if TYPE_CHECKING: # pragma: no cover
Expand Down Expand Up @@ -187,6 +192,8 @@ def _populate_sub_fields(self) -> None: # noqa: C901 (ignore complexity)
return
if origin is Callable:
return
if Literal is not None and origin is Literal:
return
if origin is Union:
types_ = []
for type_ in self.type_.__args__: # type: ignore
Expand Down
17 changes: 17 additions & 0 deletions pydantic/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,11 @@

import pydantic

try:
from typing_extensions import Literal
except ImportError:
Literal = None # type: ignore

try:
import email_validator
except ImportError:
Expand Down Expand Up @@ -285,6 +290,18 @@ def is_callable_type(type_: AnyType) -> bool:
return type_ is Callable or getattr(type_, '__origin__', None) is Callable


if sys.version_info >= (3, 7):

def is_literal_type(type_: AnyType) -> bool:
return Literal is not None and getattr(type_, '__origin__', None) is Literal


else:

def is_literal_type(type_: AnyType) -> bool:
return Literal is not None and hasattr(type_, '__values__') and type_ == Literal[type_.__values__]
dmontagu marked this conversation as resolved.
Show resolved Hide resolved


def _check_classvar(v: AnyType) -> bool:
return type(v) == type(ClassVar) and (sys.version_info < (3, 7) or getattr(v, '_name', None) == 'ClassVar')

Expand Down
36 changes: 33 additions & 3 deletions pydantic/validators.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import re
import sys
from collections import OrderedDict
from datetime import date, datetime, time, timedelta
from decimal import Decimal, DecimalException
Expand All @@ -24,7 +25,16 @@

from . import errors
from .datetime_parse import parse_date, parse_datetime, parse_duration, parse_time
from .utils import AnyCallable, AnyType, ForwardRef, change_exception, display_as_type, is_callable_type, sequence_like
from .utils import (
AnyCallable,
AnyType,
ForwardRef,
change_exception,
display_as_type,
is_callable_type,
is_literal_type,
sequence_like,
)

if TYPE_CHECKING: # pragma: no cover
from .fields import Field
Expand Down Expand Up @@ -140,7 +150,7 @@ def constant_validator(v: 'Any', field: 'Field') -> 'Any':
Schema.
"""
if v != field.default:
raise errors.WrongConstantError(given=v, const=field.default)
raise errors.WrongConstantError(given=v, permitted=[field.default])

return v

Expand Down Expand Up @@ -334,6 +344,21 @@ def callable_validator(v: Any) -> AnyCallable:
raise errors.CallableError(value=v)


def make_literal_validator(type_: Any) -> Callable[[Any], Any]:
if sys.version_info >= (3, 7):
permitted_choices = type_.__args__
else:
permitted_choices = type_.__values__
allowed_choices_set = set(permitted_choices)

def literal_validator(v: Any) -> Any:
if v not in allowed_choices_set:
raise errors.WrongConstantError(given=v, permitted=permitted_choices)
return v

return literal_validator


T = TypeVar('T')


Expand Down Expand Up @@ -409,7 +434,9 @@ def check(self, config: Type['BaseConfig']) -> bool:
]


def find_validators(type_: AnyType, config: Type['BaseConfig']) -> Generator[AnyCallable, None, None]:
def find_validators( # noqa: C901 (ignore complexity)
type_: AnyType, config: Type['BaseConfig']
) -> Generator[AnyCallable, None, None]:
if type_ is Any:
return
type_type = type(type_)
Expand All @@ -421,6 +448,9 @@ def find_validators(type_: AnyType, config: Type['BaseConfig']) -> Generator[Any
if is_callable_type(type_):
yield callable_validator
return
if is_literal_type(type_):
dmontagu marked this conversation as resolved.
Show resolved Hide resolved
yield make_literal_validator(type_)
return

supertype = _find_supertype(type_)
if supertype is not None:
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@
ujson==1.35
email-validator==1.0.4
dataclasses==0.6; python_version < '3.7'
typing-extensions==3.7.2
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ def extra(self):
extras_require={
'ujson': ['ujson>=1.35'],
'email': ['email-validator>=1.0.3'],
'typing_extensions': ['typing-extensions>=3.7.2']
dmontagu marked this conversation as resolved.
Show resolved Hide resolved
},
ext_modules=ext_modules,
)
4 changes: 2 additions & 2 deletions tests/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,9 +391,9 @@ class Model(BaseModel):
assert exc_info.value.errors() == [
{
'loc': ('a',),
'msg': 'expected constant value 3',
'msg': 'unexpected value; permitted: 3',
'type': 'value_error.const',
'ctx': {'given': 4, 'const': 3},
'ctx': {'given': 4, 'permitted': [3]},
}
]

Expand Down
42 changes: 42 additions & 0 deletions tests/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,11 @@
except ImportError:
email_validator = None

try:
import typing_extensions
except ImportError:
typing_extensions = None


class ConBytesModel(BaseModel):
v: conbytes(max_length=10) = b'foobar'
Expand Down Expand Up @@ -1661,3 +1666,40 @@ class Model(BaseModel):
{'loc': ('generic_list',), 'msg': 'value is not a valid list', 'type': 'type_error.list'},
{'loc': ('generic_dict',), 'msg': 'value is not a valid dict', 'type': 'type_error.dict'},
]


@pytest.mark.skipif(not typing_extensions, reason='typing_extensions not installed')
def test_literal_single():
class Model(BaseModel):
a: typing_extensions.Literal['a']

Model(a='a')
with pytest.raises(ValidationError) as exc_info:
Model(a='b')
assert exc_info.value.errors() == [
{
'loc': ('a',),
'msg': "unexpected value; permitted: 'a'",
'type': 'value_error.const',
'ctx': {'given': 'b', 'permitted': ('a',)},
}
]


@pytest.mark.skipif(not typing_extensions, reason='typing_extensions not installed')
def test_literal_multiple():
class Model(BaseModel):
a_or_b: typing_extensions.Literal['a', 'b']

Model(a_or_b='a')
Model(a_or_b='b')
with pytest.raises(ValidationError) as exc_info:
Model(a_or_b='c')
assert exc_info.value.errors() == [
{
'loc': ('a_or_b',),
'msg': "unexpected value; permitted: 'a', 'b'",
'type': 'value_error.const',
'ctx': {'given': 'c', 'permitted': ('a', 'b')},
}
]