Skip to content

Commit

Permalink
tweak copy and add fields copy test
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelcolvin committed Jun 21, 2017
1 parent dce9dee commit 6125b38
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 8 deletions.
14 changes: 6 additions & 8 deletions pydantic/main.py
Expand Up @@ -136,23 +136,21 @@ def construct(cls, **values):

def copy(self, *, include: Set[str]=None, exclude: Set[str]=None, update: Dict[str, Any]=None):
"""
Duplicate a model, optionally choose which fields to include, exclude and change
:param include: fields to include in new model.
Duplicate a model, optionally choose which fields to include, exclude and change.
:param include: fields to include in new model
:param exclude: fields to exclude from new model, as with values this takes precedence over include
:param update: values to change/add in the new model. Note: the data is not validated before creating
the new model: you should trust this data.
the new model: you should trust this data
:return: new model instance
"""
if include is None and exclude is None and update is None:
# skip constructing values if no arguments are set
# skip constructing values if no arguments are passed
v = self.__values__
else:
exclude = exclude or set()
v = {
**{
k: v for k, v in self.__values__.items()
if k not in exclude and (not include or k in include)
},
**{k: v for k, v in self.__values__.items() if k not in exclude and (not include or k in include)},
**(update or {})
}
return self.__class__.construct(**v)
Expand Down
3 changes: 3 additions & 0 deletions tests/test_construction.py
Expand Up @@ -32,6 +32,7 @@ def test_simple_copy():
assert m.a == m2.a == 24
assert m.b == m2.b == 10
assert m == m2
assert m.__fields__ == m2.__fields__


class ModelTwo(BaseModel):
Expand Down Expand Up @@ -97,6 +98,7 @@ def test_simple_pickle():
assert m is not m2
assert tuple(m) == (('a', 24.0), ('b', 10))
assert tuple(m2) == (('a', 24.0), ('b', 10))
assert m.__fields__ == m2.__fields__


def test_recursive_pickle():
Expand All @@ -106,3 +108,4 @@ def test_recursive_pickle():

assert m.d.a == 123.45
assert m2.d.a == 123.45
assert m.__fields__ == m2.__fields__

0 comments on commit 6125b38

Please sign in to comment.