diff --git a/changes/921-samuelcolvin.md b/changes/921-samuelcolvin.md new file mode 100644 index 0000000000..afbdef521d --- /dev/null +++ b/changes/921-samuelcolvin.md @@ -0,0 +1 @@ +Allow abstracts sets (eg. dict keys) in the `include` and `exclude` arguments of `dict()` diff --git a/pydantic/main.py b/pydantic/main.py index a663927343..2f25db5d09 100644 --- a/pydantic/main.py +++ b/pydantic/main.py @@ -24,7 +24,7 @@ from .class_validators import ValidatorListDict from .types import ModelOrDc from .typing import CallableGenerator, TupleGenerator, DictStrAny, DictAny, SetStr - from .typing import SetIntStr, DictIntStrAny, ReprArgs # noqa: F401 + from .typing import AbstractSetIntStr, DictIntStrAny, ReprArgs # noqa: F401 ConfigType = Type['BaseConfig'] Model = TypeVar('Model', bound='BaseModel') @@ -302,8 +302,8 @@ def __setstate__(self, state: 'DictAny') -> None: def dict( self, *, - include: Union['SetIntStr', 'DictIntStrAny'] = None, - exclude: Union['SetIntStr', 'DictIntStrAny'] = None, + include: Union['AbstractSetIntStr', 'DictIntStrAny'] = None, + exclude: Union['AbstractSetIntStr', 'DictIntStrAny'] = None, by_alias: bool = False, skip_defaults: bool = None, exclude_unset: bool = False, @@ -344,8 +344,8 @@ def _get_key_factory(self, by_alias: bool) -> Callable[..., str]: def json( self, *, - include: Union['SetIntStr', 'DictIntStrAny'] = None, - exclude: Union['SetIntStr', 'DictIntStrAny'] = None, + include: Union['AbstractSetIntStr', 'DictIntStrAny'] = None, + exclude: Union['AbstractSetIntStr', 'DictIntStrAny'] = None, by_alias: bool = False, skip_defaults: bool = None, exclude_unset: bool = False, @@ -452,8 +452,8 @@ def construct(cls: Type['Model'], values: 'DictAny', fields_set: 'SetStr') -> 'M def copy( self: 'Model', *, - include: Union['SetIntStr', 'DictIntStrAny'] = None, - exclude: Union['SetIntStr', 'DictIntStrAny'] = None, + include: Union['AbstractSetIntStr', 'DictIntStrAny'] = None, + exclude: Union['AbstractSetIntStr', 'DictIntStrAny'] = None, update: 'DictStrAny' = None, deep: bool = False, ) -> 'Model': @@ -539,8 +539,8 @@ def _get_value( v: Any, to_dict: bool, by_alias: bool, - include: Optional[Union['SetIntStr', 'DictIntStrAny']], - exclude: Optional[Union['SetIntStr', 'DictIntStrAny']], + include: Optional[Union['AbstractSetIntStr', 'DictIntStrAny']], + exclude: Optional[Union['AbstractSetIntStr', 'DictIntStrAny']], exclude_unset: bool, exclude_defaults: bool, ) -> Any: @@ -616,8 +616,8 @@ def _iter( to_dict: bool = False, by_alias: bool = False, allowed_keys: Optional['SetStr'] = None, - include: Union['SetIntStr', 'DictIntStrAny'] = None, - exclude: Union['SetIntStr', 'DictIntStrAny'] = None, + include: Union['AbstractSetIntStr', 'DictIntStrAny'] = None, + exclude: Union['AbstractSetIntStr', 'DictIntStrAny'] = None, exclude_unset: bool = False, exclude_defaults: bool = False, ) -> 'TupleGenerator': @@ -646,8 +646,8 @@ def _iter( def _calculate_keys( self, - include: Optional[Union['SetIntStr', 'DictIntStrAny']], - exclude: Optional[Union['SetIntStr', 'DictIntStrAny']], + include: Optional[Union['AbstractSetIntStr', 'DictIntStrAny']], + exclude: Optional[Union['AbstractSetIntStr', 'DictIntStrAny']], exclude_unset: bool, update: Optional['DictStrAny'] = None, ) -> Optional['SetStr']: diff --git a/pydantic/typing.py b/pydantic/typing.py index 9cb9e6b6c3..170020a86d 100644 --- a/pydantic/typing.py +++ b/pydantic/typing.py @@ -2,6 +2,7 @@ from enum import Enum from typing import ( # type: ignore TYPE_CHECKING, + AbstractSet, Any, ClassVar, Dict, @@ -62,7 +63,7 @@ def evaluate_forwardref(type_, globalns, localns): # type: ignore SetStr = Set[str] ListStr = List[str] IntStr = Union[int, str] - SetIntStr = Set[IntStr] + AbstractSetIntStr = AbstractSet[IntStr] DictIntStrAny = Dict[IntStr, Any] CallableGenerator = Generator[AnyCallable, None, None] ReprArgs = Sequence[Tuple[Optional[str], Any]] @@ -89,10 +90,11 @@ def evaluate_forwardref(type_, globalns, localns): # type: ignore 'SetStr', 'ListStr', 'IntStr', - 'SetIntStr', + 'AbstractSetIntStr', 'DictIntStrAny', 'CallableGenerator', 'ReprArgs', + 'CallableGenerator', ) diff --git a/pydantic/utils.py b/pydantic/utils.py index d704b0c573..5860349100 100644 --- a/pydantic/utils.py +++ b/pydantic/utils.py @@ -3,6 +3,7 @@ from importlib import import_module from typing import ( TYPE_CHECKING, + AbstractSet, Any, Callable, Dict, @@ -27,7 +28,7 @@ if TYPE_CHECKING: from .main import BaseModel # noqa: F401 - from .typing import SetIntStr, DictIntStrAny, IntStr, ReprArgs # noqa: F401 + from .typing import AbstractSetIntStr, DictIntStrAny, IntStr, ReprArgs # noqa: F401 KeyType = TypeVar('KeyType') @@ -247,15 +248,15 @@ class ValueItems(Representation): __slots__ = ('_items', '_type') - def __init__(self, value: Any, items: Union['SetIntStr', 'DictIntStrAny']) -> None: + def __init__(self, value: Any, items: Union['AbstractSetIntStr', 'DictIntStrAny']) -> None: if TYPE_CHECKING: - self._items: Union['SetIntStr', 'DictIntStrAny'] + self._items: Union['AbstractSetIntStr', 'DictIntStrAny'] self._type: Type[Union[set, dict]] # type: ignore # For further type checks speed-up if isinstance(items, dict): self._type = dict - elif isinstance(items, set): + elif isinstance(items, AbstractSet): self._type = set else: raise TypeError(f'Unexpected type of exclude value {type(items)}') @@ -288,7 +289,7 @@ def is_included(self, item: Any) -> bool: return item in self._items @no_type_check - def for_element(self, e: 'IntStr') -> Optional[Union['SetIntStr', 'DictIntStrAny']]: + def for_element(self, e: 'IntStr') -> Optional[Union['AbstractSetIntStr', 'DictIntStrAny']]: """ :param e: key or index of element on value :return: raw values for elemet if self._items is dict and contain needed element @@ -301,8 +302,8 @@ def for_element(self, e: 'IntStr') -> Optional[Union['SetIntStr', 'DictIntStrAny @no_type_check def _normalize_indexes( - self, items: Union['SetIntStr', 'DictIntStrAny'], v_length: int - ) -> Union['SetIntStr', 'DictIntStrAny']: + self, items: Union['AbstractSetIntStr', 'DictIntStrAny'], v_length: int + ) -> Union['AbstractSetIntStr', 'DictIntStrAny']: """ :param items: dict or set of indexes which will be normalized :param v_length: length of sequence indexes of which will be diff --git a/tests/test_edge_cases.py b/tests/test_edge_cases.py index c34c88f88c..377e80c849 100644 --- a/tests/test_edge_cases.py +++ b/tests/test_edge_cases.py @@ -420,23 +420,31 @@ class Model(BaseModel): assert m.dict(include={'a', 'b', 'c'}, exclude={'b'}, exclude_defaults=True) == {'a': 1} assert m.dict(include={'a', 'b', 'c'}, exclude={'a', 'c'}, exclude_defaults=True) == {'b': 2} + # abstract set + assert m.dict(include={'a': 1}.keys()) == {'a': 1} + assert m.dict(exclude={'a': 1}.keys()) == {'b': 2, 'c': 3, 'd': 4, 'e': 5, 'f': 7} + + assert m.dict(include={'a': 1}.keys(), exclude_unset=True) == {'a': 1} + assert m.dict(exclude={'a': 1}.keys(), exclude_unset=True) == {'b': 2, 'e': 5, 'f': 7} + def test_skip_defaults_deprecated(): class Model(BaseModel): x: int + b: int = 2 m = Model(x=1) match = r'Model.dict\(\): "skip_defaults" is deprecated and replaced by "exclude_unset"' with pytest.warns(DeprecationWarning, match=match): - assert m.dict(skip_defaults=True) + assert m.dict(skip_defaults=True) == m.dict(exclude_unset=True) with pytest.warns(DeprecationWarning, match=match): - assert m.dict(skip_defaults=False) + assert m.dict(skip_defaults=False) == m.dict(exclude_unset=False) match = r'Model.json\(\): "skip_defaults" is deprecated and replaced by "exclude_unset"' with pytest.warns(DeprecationWarning, match=match): - assert m.json(skip_defaults=True) + assert m.json(skip_defaults=True) == m.json(exclude_unset=True) with pytest.warns(DeprecationWarning, match=match): - assert m.json(skip_defaults=False) + assert m.json(skip_defaults=False) == m.json(exclude_unset=False) def test_advanced_exclude():