Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(fields): handle properly default value for type Callable #2094

Merged
merged 8 commits into from Feb 11, 2021
1 change: 1 addition & 0 deletions changes/1596-PrettyWood.md
@@ -0,0 +1 @@
Handle properly default value for fields of type `Callable`
7 changes: 4 additions & 3 deletions pydantic/main.py
Expand Up @@ -198,7 +198,8 @@ def validate_custom_root_type(fields: Dict[str, ModelField]) -> None:
raise ValueError('__root__ cannot be mixed with other fields')


UNTOUCHED_TYPES = FunctionType, property, type, classmethod, staticmethod
FIELD_DEFAULT_VALUE_UNTOUCHED_TYPES = property, type, classmethod, staticmethod
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we make this name shorter. I think we need to describe what's going in here in a comment.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added a comment and renamed it to ANNOTATED_FIELD_UNTOUCHED_TYPES. Still a bit long but I couldn't find an explicit name shorter than that.

UNTOUCHED_TYPES = (FunctionType,) + FIELD_DEFAULT_VALUE_UNTOUCHED_TYPES

# Note `ModelMetaclass` refers to `BaseModel`, but is also used to *create* `BaseModel`, so we need to add this extra
# (somewhat hacky) boolean to keep track of whether we've created the `BaseModel` class yet, and therefore whether it's
Expand Down Expand Up @@ -245,7 +246,6 @@ def __new__(mcs, name, bases, namespace, **kwargs): # noqa C901
class_vars = set()
if (namespace.get('__module__'), namespace.get('__qualname__')) != ('pydantic.main', 'BaseModel'):
annotations = resolve_annotations(namespace.get('__annotations__', {}), namespace.get('__module__', None))
untouched_types = UNTOUCHED_TYPES + config.keep_untouched
# annotation only fields need to come first in fields
for ann_name, ann_type in annotations.items():
if is_classvar(ann_type):
Expand All @@ -254,7 +254,7 @@ def __new__(mcs, name, bases, namespace, **kwargs): # noqa C901
validate_field_name(bases, ann_name)
value = namespace.get(ann_name, Undefined)
if (
isinstance(value, untouched_types)
isinstance(value, FIELD_DEFAULT_VALUE_UNTOUCHED_TYPES)
and ann_type != PyObject
and not lenient_issubclass(get_origin(ann_type), Type)
):
Expand All @@ -269,6 +269,7 @@ def __new__(mcs, name, bases, namespace, **kwargs): # noqa C901
elif ann_name not in namespace and config.underscore_attrs_are_private:
private_attributes[ann_name] = PrivateAttr()

untouched_types = UNTOUCHED_TYPES + config.keep_untouched
for var_name, value in namespace.items():
can_be_changed = var_name not in class_vars and not isinstance(value, untouched_types)
if isinstance(value, ModelPrivateAttr):
Expand Down
46 changes: 46 additions & 0 deletions tests/test_main.py
@@ -1,4 +1,5 @@
import sys
from dataclasses import dataclass as stdlib_dataclass, field
from enum import Enum
from typing import Any, Callable, ClassVar, Dict, List, Mapping, Optional, Type, get_type_hints
from uuid import UUID, uuid4
Expand All @@ -18,6 +19,7 @@
root_validator,
validator,
)
from pydantic.dataclasses import dataclass as pydantic_dataclass
from pydantic.typing import Literal


Expand Down Expand Up @@ -1374,3 +1376,47 @@ class M(BaseModel):
a: int

get_type_hints(M.__config__)


def foo(arg1, arg2):
PrettyWood marked this conversation as resolved.
Show resolved Hide resolved
return arg1, arg2


@pydantic_dataclass
PrettyWood marked this conversation as resolved.
Show resolved Hide resolved
class HasCallablesDC:
non_default_callable: Callable
default_callable: Callable = lambda x: foo(x, 'default')
default_callable_factory: Callable = field(default=lambda x: foo(x, 'factory'))


class HasCallablesModel(BaseModel):
non_default_callable: Callable
default_callable: Callable = lambda x: foo(x, 'default')
default_callable_factory: Callable = Field(default_factory=lambda: lambda x: foo(x, 'factory'))


@stdlib_dataclass
class HasCallablesStdlibDC:
non_default_callable: Callable
default_callable: Callable = lambda x: foo(x, 'default')
default_callable_factory: Callable = field(default_factory=lambda: lambda x: foo(x, 'factory'))


@pytest.mark.parametrize('cls', [HasCallablesModel, HasCallablesDC])
def test_pydantic_callable_field(cls):
"""pydantic callable fields behaviour should be the same as stdlib dataclass"""

def non_default_callable(x):
return foo(x, 'nondefault')

a1 = cls(non_default_callable=non_default_callable)
a2 = HasCallablesStdlibDC(non_default_callable=non_default_callable)

# call non_default
assert a1.non_default_callable('hello') == a2.non_default_callable('hello')

# call default_factory
assert a1.default_callable_factory('hello') == a2.default_callable_factory('hello')

# call default
assert a1.default_callable('hello') == a2.default_callable('hello')