Skip to content

Commit

Permalink
Tracking for setting attributes (#389)
Browse files Browse the repository at this point in the history
fix #378

* Tracking for setting attributes

* Fixes accidental leak of fields

* Allows defaults fields to be recursively set

* Docs and history for skip_defaults

* Mypy fix on calculate keys

* Update pydantic/main.py

Co-Authored-By: dgasmith <dgasmith@icloud.com>

* Update pydantic/main.py

Co-Authored-By: dgasmith <dgasmith@icloud.com>

* Update HISTORY.rst

Co-Authored-By: dgasmith <dgasmith@icloud.com>

* Cleanup pass based off review

* Simplifies constructors based on feedback

* Makes mypy happy with exlicit KeysView

* SetOrKeys and faster key search

* Formats files once more

* add tests for dict, pickle and construct

* fixes for dict, pickle and construct

* correct field_set for extra.ignore

* Fixes format
  • Loading branch information
dgasmith authored and samuelcolvin committed Feb 13, 2019
1 parent baade9a commit 96e3e74
Show file tree
Hide file tree
Showing 7 changed files with 239 additions and 50 deletions.
2 changes: 2 additions & 0 deletions HISTORY.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ v0.20.0 (unreleased)
* **breaking change** (maybe): more sophisticated argument parsing for validators, any subset of
``values``, ``config`` and ``field`` is now permitted, eg. ``(cls, value, field)``,
however the variadic key word argument ("``**kwargs``") **must** be called ``kwargs``, #388 by @samuelcolvin
* Adds ``skip_defaults`` argument to ``BaseModel.dict()`` to allow skipping of fields that were not
explicitly set, #389 by @dgasmith

v0.19.0 (2019-02-04)
....................
Expand Down
4 changes: 4 additions & 0 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -634,6 +634,10 @@ converted to dicts, ``copy`` allows models to be duplicated, this is particularl
respectively. ``copy`` accepts extra keyword arguments, ``update``, which accepts a ``dict`` mapping attributes
to new values that will be applied as the model is duplicated and ``deep`` to make a deep copy of the model.

``dict`` and ``json`` take the optional ``skip_defaults`` keyword argument which will skip attributes that were
not explicitly set. This is useful to reduce the serialized size of models thats have many default fields that
are not often changed.

.. literalinclude:: examples/copy_dict.py

Serialisation
Expand Down
130 changes: 91 additions & 39 deletions pydantic/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,21 @@

from .class_validators import ValidatorGroup, extract_validators, inherit_validators
from .error_wrappers import ErrorWrapper, ValidationError
from .errors import ConfigError, ExtraError, MissingError
from .errors import ConfigError, DictError, ExtraError, MissingError
from .fields import Field
from .json import custom_pydantic_encoder, pydantic_encoder
from .parse import Protocol, load_file, load_str_bytes
from .schema import model_schema
from .types import StrBytes
from .utils import AnyCallable, AnyType, ForwardRef, resolve_annotations, truncate, validate_field_name
from .validators import dict_validator
from .utils import (
AnyCallable,
AnyType,
ForwardRef,
change_exception,
resolve_annotations,
truncate,
validate_field_name,
)

if TYPE_CHECKING: # pragma: no cover
from .types import CallableGenerator
Expand All @@ -42,6 +49,7 @@
TupleGenerator = Generator[Tuple[str, Any], None, None]
DictStrAny = Dict[str, Any]
ConfigType = Type['BaseConfig']
DictAny = Dict[Any, Any]


class Extra(str, Enum):
Expand Down Expand Up @@ -203,15 +211,21 @@ class BaseModel(metaclass=MetaModel):
__validators__: Dict[str, AnyCallable] = {}
__config__: Type[BaseConfig] = BaseConfig
_json_encoder: Callable[[Any], Any] = lambda x: x
_schema_cache: Dict[Any, Any] = {}
_schema_cache: 'DictAny' = {}

Config = BaseConfig
__slots__ = ('__values__',)
__slots__ = ('__values__', '__fields_set__')

def __init__(self, **data: Any) -> None:
if TYPE_CHECKING: # pragma: no cover
self.__values__: Dict[str, Any] = {}
self.__setstate__(self._process_values(data))
self.__fields_set__: Set[str] = set()
object.__setattr__(self, '__values__', self._process_values(data))
if self.__config__.extra is Extra.allow:
fields_set = set(data.keys())
else:
fields_set = data.keys() & self.__values__.keys() # type: ignore
object.__setattr__(self, '__fields_set__', fields_set)

@no_type_check
def __getattr__(self, name):
Expand All @@ -232,27 +246,34 @@ def __setattr__(self, name, value):
raise ValidationError([error_])
else:
self.__values__[name] = value_
self.__fields_set__.add(name)
else:
self.__values__[name] = value
self.__fields_set__.add(name)

def __getstate__(self) -> Dict[Any, Any]:
return self.__values__
def __getstate__(self) -> 'DictAny':
return {'__values__': self.__values__, '__fields_set__': self.__fields_set__}

def __setstate__(self, state: Dict[Any, Any]) -> None:
object.__setattr__(self, '__values__', state)
def __setstate__(self, state: 'DictAny') -> None:
object.__setattr__(self, '__values__', state['__values__'])
object.__setattr__(self, '__fields_set__', state['__fields_set__'])

def dict(self, *, include: Set[str] = None, exclude: Set[str] = set(), by_alias: bool = False) -> 'DictStrAny':
def dict(
self, *, include: Set[str] = None, exclude: Set[str] = None, by_alias: bool = False, skip_defaults: bool = False
) -> 'DictStrAny':
"""
Generate a dictionary representation of the model, optionally specifying which fields to include or exclude.
"""
get_key = self._get_key_factory(by_alias)
get_key = partial(get_key, self.fields)

return {
get_key(k): v
for k, v in self._iter(by_alias=by_alias)
if k not in exclude and (not include or k in include)
}
return_keys = self._calculate_keys(include=include, exclude=exclude, skip_defaults=skip_defaults)
if return_keys is None:
return {get_key(k): v for k, v in self._iter(by_alias=by_alias, skip_defaults=skip_defaults)}
else:
return {
get_key(k): v for k, v in self._iter(by_alias=by_alias, skip_defaults=skip_defaults) if k in return_keys
}

def _get_key_factory(self, by_alias: bool) -> Callable[..., str]:
if by_alias:
Expand All @@ -263,9 +284,10 @@ def _get_key_factory(self, by_alias: bool) -> Callable[..., str]:
def json(
self,
*,
include: Set[str] = set(),
exclude: Set[str] = set(),
include: Set[str] = None,
exclude: Set[str] = None,
by_alias: bool = False,
skip_defaults: bool = False,
encoder: Optional[Callable[[Any], Any]] = None,
**dumps_kwargs: Any,
) -> str:
Expand All @@ -276,11 +298,13 @@ def json(
"""
encoder = cast(Callable[[Any], Any], encoder or self._json_encoder)
return json.dumps(
self.dict(include=include, exclude=exclude, by_alias=by_alias), default=encoder, **dumps_kwargs
self.dict(include=include, exclude=exclude, by_alias=by_alias, skip_defaults=skip_defaults),
default=encoder,
**dumps_kwargs,
)

@classmethod
def parse_obj(cls, obj: Dict[Any, Any]) -> 'BaseModel':
def parse_obj(cls, obj: 'DictAny') -> 'BaseModel':
if not isinstance(obj, dict):
exc = TypeError(f'{cls.__name__} expected dict not {type(obj).__name__}')
raise ValidationError([ErrorWrapper(exc, loc='__obj__')])
Expand Down Expand Up @@ -318,13 +342,14 @@ def parse_file(
return cls.parse_obj(obj)

@classmethod
def construct(cls, **values: Any) -> 'BaseModel':
def construct(cls, values: 'DictAny', fields_set: Set[str]) -> 'BaseModel':
"""
Creates a new model and set __values__ without any validation, thus values should already be trusted.
Chances are you don't want to use this method directly.
"""
m = cls.__new__(cls)
m.__setstate__(values)
object.__setattr__(m, '__values__', values)
object.__setattr__(m, '__fields_set__', fields_set)
return m

def copy(
Expand All @@ -344,14 +369,16 @@ def copy(
# skip constructing values if no arguments are passed
v = self.__values__
else:
exclude = exclude or set()
v = {
**{k: v for k, v in self.__values__.items() if k not in exclude and (not include or k in include)},
**(update or {}),
}
return_keys = self._calculate_keys(include=include, exclude=exclude, skip_defaults=False)
if return_keys:
v = {**{k: v for k, v in self.__values__.items() if k in return_keys}, **(update or {})}
else:
v = {**self.__values__, **(update or {})}

if deep:
v = deepcopy(v)
return self.__class__.construct(**v)
m = self.__class__.construct(v, self.__fields_set__.copy())
return m

@property
def fields(self) -> Dict[str, Field]:
Expand All @@ -374,29 +401,34 @@ def schema_json(cls, *, by_alias: bool = True, **dumps_kwargs: Any) -> str:

@classmethod
def __get_validators__(cls) -> 'CallableGenerator':
yield dict_validator
yield cls.validate

@classmethod
def validate(cls, value: 'DictStrAny') -> 'BaseModel':
return cls(**value)
def validate(cls, value: Union['DictStrAny', 'BaseModel']) -> 'BaseModel':
if isinstance(value, dict):
return cls(**value)
elif isinstance(value, BaseModel):
return value.copy()
else:
with change_exception(DictError, TypeError, ValueError):
return cls(**dict(value))

def _process_values(self, input_data: Any) -> 'DictStrAny':
# (casting here is slow so use ignore)
return validate_model(self, input_data) # type: ignore

@classmethod
def _get_value(cls, v: Any, by_alias: bool) -> Any:
def _get_value(cls, v: Any, by_alias: bool, skip_defaults: bool) -> Any:
if isinstance(v, BaseModel):
return v.dict(by_alias=by_alias)
return v.dict(by_alias=by_alias, skip_defaults=skip_defaults)
elif isinstance(v, list):
return [cls._get_value(v_, by_alias=by_alias) for v_ in v]
return [cls._get_value(v_, by_alias=by_alias, skip_defaults=skip_defaults) for v_ in v]
elif isinstance(v, dict):
return {k_: cls._get_value(v_, by_alias=by_alias) for k_, v_ in v.items()}
return {k_: cls._get_value(v_, by_alias=by_alias, skip_defaults=skip_defaults) for k_, v_ in v.items()}
elif isinstance(v, set):
return {cls._get_value(v_, by_alias=by_alias) for v_ in v}
return {cls._get_value(v_, by_alias=by_alias, skip_defaults=skip_defaults) for v_ in v}
elif isinstance(v, tuple):
return tuple(cls._get_value(v_, by_alias=by_alias) for v_ in v)
return tuple(cls._get_value(v_, by_alias=by_alias, skip_defaults=skip_defaults) for v_ in v)
else:
return v

Expand All @@ -418,9 +450,29 @@ def __iter__(self) -> 'AnyGenerator':
"""
yield from self._iter()

def _iter(self, by_alias: bool = False) -> 'TupleGenerator':
def _iter(self, by_alias: bool = False, skip_defaults: bool = False) -> 'TupleGenerator':
for k, v in self.__values__.items():
yield k, self._get_value(v, by_alias=by_alias)
yield k, self._get_value(v, by_alias=by_alias, skip_defaults=skip_defaults)

def _calculate_keys(
self, include: Set[str] = None, exclude: Optional[Set[str]] = None, skip_defaults: bool = False
) -> Optional[Set[str]]:

if include is None and exclude is None and skip_defaults is False:
return None

if skip_defaults:
keys = self.__fields_set__.copy()
else:
keys = set(self.__values__.keys())

if include:
keys &= include

if exclude:
keys -= exclude

return keys

def __eq__(self, other: Any) -> bool:
if isinstance(other, BaseModel):
Expand Down
21 changes: 18 additions & 3 deletions tests/test_construction.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,18 @@


class Model(BaseModel):
a: float = ...
a: float
b: int = 10


def test_simple_construct():
m = Model.construct(a=40, b=10)
m = Model.construct(dict(a=40, b=10), {'a', 'b'})
assert m.a == 40
assert m.b == 10


def test_construct_missing():
m = Model.construct(a='not a float')
m = Model.construct(dict(a='not a float'), {'a'})
assert m.a == 'not a float'
with pytest.raises(AttributeError) as exc_info:
print(m.b)
Expand Down Expand Up @@ -112,6 +112,14 @@ def test_copy_update():
assert m != m2


def test_copy_set_fields():
m = ModelTwo(a=24, d=Model(a='12'))
m2 = m.copy()

assert m.dict(skip_defaults=True) == {'a': 24.0, 'd': {'a': 12}}
assert m.dict(skip_defaults=True) == m2.dict(skip_defaults=True)


def test_simple_pickle():
m = Model(a='24')
b = pickle.dumps(m)
Expand Down Expand Up @@ -150,3 +158,10 @@ class Config:
assert str(m2) == 'Model a=40 b=12'
with pytest.raises(TypeError):
m2.b = 13


def test_pickle_fields_set():
m = Model(a=24)
assert m.dict(skip_defaults=True) == {'a': 24}
m2 = pickle.loads(pickle.dumps(m))
assert m2.dict(skip_defaults=True) == {'a': 24}
73 changes: 73 additions & 0 deletions tests/test_edge_cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,79 @@ class Model(BaseModel):
assert m.dict(include={'a', 'b'}, exclude={'a'}) == {'b': 2}


def test_include_exclude_default():
class Model(BaseModel):
a: int
b: int
c: int = 3
d: int = 4

m = Model(a=1, b=2)
assert m.dict() == {'a': 1, 'b': 2, 'c': 3, 'd': 4}
assert m.__fields_set__ == {'a', 'b'}
assert m.dict(skip_defaults=True) == {'a': 1, 'b': 2}

assert m.dict(include={'a'}, skip_defaults=True) == {'a': 1}
assert m.dict(include={'c'}, skip_defaults=True) == {}

assert m.dict(exclude={'a'}, skip_defaults=True) == {'b': 2}
assert m.dict(exclude={'c'}, skip_defaults=True) == {'a': 1, 'b': 2}

assert m.dict(include={'a', 'b', 'c'}, exclude={'b'}, skip_defaults=True) == {'a': 1}
assert m.dict(include={'a', 'b', 'c'}, exclude={'a', 'c'}, skip_defaults=True) == {'b': 2}


def test_field_set_ignore_extra():
class Model(BaseModel):
a: int
b: int
c: int = 3

class Config:
extra = Extra.ignore

m = Model(a=1, b=2)
assert m.dict() == {'a': 1, 'b': 2, 'c': 3}
assert m.__fields_set__ == {'a', 'b'}
assert m.dict(skip_defaults=True) == {'a': 1, 'b': 2}

m2 = Model(a=1, b=2, d=4)
assert m2.dict() == {'a': 1, 'b': 2, 'c': 3}
assert m2.__fields_set__ == {'a', 'b'}
assert m2.dict(skip_defaults=True) == {'a': 1, 'b': 2}


def test_field_set_allow_extra():
class Model(BaseModel):
a: int
b: int
c: int = 3

class Config:
extra = Extra.allow

m = Model(a=1, b=2)
assert m.dict() == {'a': 1, 'b': 2, 'c': 3}
assert m.__fields_set__ == {'a', 'b'}
assert m.dict(skip_defaults=True) == {'a': 1, 'b': 2}

m2 = Model(a=1, b=2, d=4)
assert m2.dict() == {'a': 1, 'b': 2, 'c': 3, 'd': 4}
assert m2.__fields_set__ == {'a', 'b', 'd'}
assert m2.dict(skip_defaults=True) == {'a': 1, 'b': 2, 'd': 4}


def test_field_set_field_name():
class Model(BaseModel):
a: int
field_set: int
b: int = 3

assert Model(a=1, field_set=2).dict() == {'a': 1, 'field_set': 2, 'b': 3}
assert Model(a=1, field_set=2).dict(skip_defaults=True) == {'a': 1, 'field_set': 2}
assert Model.construct(dict(a=1, field_set=3), {'a', 'field_set'}).dict() == {'a': 1, 'field_set': 3}


def test_values_order():
class Model(BaseModel):
a: int = 1
Expand Down
Loading

0 comments on commit 96e3e74

Please sign in to comment.