From 6125b38a2a7f5cce83f64109265889ccb9d2ff9b Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Wed, 21 Jun 2017 15:35:44 +0100 Subject: [PATCH] tweak copy and add fields copy test --- pydantic/main.py | 14 ++++++-------- tests/test_construction.py | 3 +++ 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/pydantic/main.py b/pydantic/main.py index cfd77086d17..421eaa06d8a 100644 --- a/pydantic/main.py +++ b/pydantic/main.py @@ -136,23 +136,21 @@ def construct(cls, **values): def copy(self, *, include: Set[str]=None, exclude: Set[str]=None, update: Dict[str, Any]=None): """ - Duplicate a model, optionally choose which fields to include, exclude and change - :param include: fields to include in new model. + Duplicate a model, optionally choose which fields to include, exclude and change. + + :param include: fields to include in new model :param exclude: fields to exclude from new model, as with values this takes precedence over include :param update: values to change/add in the new model. Note: the data is not validated before creating - the new model: you should trust this data. + the new model: you should trust this data :return: new model instance """ if include is None and exclude is None and update is None: - # skip constructing values if no arguments are set + # skip constructing values if no arguments are passed v = self.__values__ else: exclude = exclude or set() v = { - **{ - k: v for k, v in self.__values__.items() - if k not in exclude and (not include or k in include) - }, + **{k: v for k, v in self.__values__.items() if k not in exclude and (not include or k in include)}, **(update or {}) } return self.__class__.construct(**v) diff --git a/tests/test_construction.py b/tests/test_construction.py index 51b9431c8f8..bc7df6f6387 100644 --- a/tests/test_construction.py +++ b/tests/test_construction.py @@ -32,6 +32,7 @@ def test_simple_copy(): assert m.a == m2.a == 24 assert m.b == m2.b == 10 assert m == m2 + assert m.__fields__ == m2.__fields__ class ModelTwo(BaseModel): @@ -97,6 +98,7 @@ def test_simple_pickle(): assert m is not m2 assert tuple(m) == (('a', 24.0), ('b', 10)) assert tuple(m2) == (('a', 24.0), ('b', 10)) + assert m.__fields__ == m2.__fields__ def test_recursive_pickle(): @@ -106,3 +108,4 @@ def test_recursive_pickle(): assert m.d.a == 123.45 assert m2.d.a == 123.45 + assert m.__fields__ == m2.__fields__