Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Minor refactoring -- no behavioral changes #8236

Merged
merged 5 commits into from
Nov 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
93 changes: 93 additions & 0 deletions pydantic/_internal/_constructor_signature_generators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
from __future__ import annotations

from inspect import Parameter, Signature, signature
from typing import TYPE_CHECKING, Any, Callable

from ._config import ConfigWrapper
from ._utils import is_valid_identifier

if TYPE_CHECKING:
from ..fields import FieldInfo


def generate_pydantic_signature(
init: Callable[..., None],
fields: dict[str, FieldInfo],
config_wrapper: ConfigWrapper,
parameter_post_processor: Callable[[Parameter], Parameter] = lambda x: x,
) -> Signature:
"""Generate signature for a pydantic BaseModel or dataclass.

Args:
init: The class init.
fields: The model fields.
config_wrapper: The config wrapper instance.
parameter_post_processor: Optional additional processing for parameter

Returns:
The dataclass/BaseModel subclass signature.
"""
from itertools import islice

present_params = signature(init).parameters.values()
merged_params: dict[str, Parameter] = {}
var_kw = None
use_var_kw = False

for param in islice(present_params, 1, None): # skip self arg
# inspect does "clever" things to show annotations as strings because we have
# `from __future__ import annotations` in main, we don't want that
if param.annotation == 'Any':
param = param.replace(annotation=Any)
if param.kind is param.VAR_KEYWORD:
var_kw = param
continue
merged_params[param.name] = parameter_post_processor(param)

if var_kw: # if custom init has no var_kw, fields which are not declared in it cannot be passed through
allow_names = config_wrapper.populate_by_name
for field_name, field in fields.items():
# when alias is a str it should be used for signature generation
if isinstance(field.alias, str):
param_name = field.alias
else:
param_name = field_name

if field_name in merged_params or param_name in merged_params:
continue

if not is_valid_identifier(param_name):
if allow_names and is_valid_identifier(field_name):
param_name = field_name
else:
use_var_kw = True
continue

kwargs = {} if field.is_required() else {'default': field.get_default(call_default_factory=False)}
merged_params[param_name] = parameter_post_processor(
Parameter(param_name, Parameter.KEYWORD_ONLY, annotation=field.rebuild_annotation(), **kwargs)
)

if config_wrapper.extra == 'allow':
use_var_kw = True

if var_kw and use_var_kw:
# Make sure the parameter for extra kwargs
# does not have the same name as a field
default_model_signature = [
('self', Parameter.POSITIONAL_ONLY),
('data', Parameter.VAR_KEYWORD),
]
if [(p.name, p.kind) for p in present_params] == default_model_signature:
# if this is the standard model signature, use extra_data as the extra args name
var_kw_name = 'extra_data'
else:
# else start from var_kw
var_kw_name = var_kw.name

# generate a name that's definitely unique
while var_kw_name in fields:
var_kw_name += '_'
merged_params[var_kw_name] = parameter_post_processor(var_kw.replace(name=var_kw_name))

return Signature(parameters=list(merged_params.values()), return_annotation=None)
33 changes: 10 additions & 23 deletions pydantic/_internal/_dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import typing
import warnings
from functools import partial, wraps
from inspect import Parameter, Signature
from inspect import Parameter
from typing import Any, Callable, ClassVar

from pydantic_core import (
Expand All @@ -23,9 +23,9 @@
from ..plugin._schema_validator import create_schema_validator
from ..warnings import PydanticDeprecatedSince20
from . import _config, _decorators, _typing_extra
from ._config import ConfigWrapper
from ._constructor_signature_generators import generate_pydantic_signature
from ._fields import collect_dataclass_fields
from ._generate_schema import GenerateSchema, generate_pydantic_signature
from ._generate_schema import GenerateSchema
from ._generics import get_standard_typevars_map
from ._mock_val_ser import set_dataclass_mocks
from ._schema_generation_shared import CallbackGetCoreSchemaHandler
Expand Down Expand Up @@ -125,7 +125,12 @@ def complete_dataclass(
)

# This needs to be called before we change the __init__
sig = generate_dataclass_signature(cls, cls.__pydantic_fields__, config_wrapper) # type: ignore
sig = generate_pydantic_signature(
init=cls.__init__,
fields=cls.__pydantic_fields__, # type: ignore
config_wrapper=config_wrapper,
parameter_post_processor=process_param_defaults,
)

# dataclass.__init__ must be defined here so its `__qualname__` can be changed since functions can't be copied.
def __init__(__dataclass_self__: PydanticDataclass, *args: Any, **kwargs: Any) -> None:
Expand Down Expand Up @@ -188,7 +193,7 @@ def validated_setattr(instance: Any, __field: str, __value: str) -> None:


def process_param_defaults(param: Parameter) -> Parameter:
"""Custom processing where the parameter default is of type FieldInfo
"""Modify the signature for a parameter in a dataclass where the default value is a FieldInfo instance.

Args:
param (Parameter): The parameter
Expand Down Expand Up @@ -226,24 +231,6 @@ def process_param_defaults(param: Parameter) -> Parameter:
return param


def generate_dataclass_signature(
cls: type[StandardDataclass], fields: dict[str, FieldInfo], config_wrapper: ConfigWrapper
) -> Signature:
"""Generate signature for a pydantic dataclass.

Args:
cls: The dataclass.
fields: The model fields.
config_wrapper: The config wrapper instance.

Returns:
The dataclass signature.
"""
return generate_pydantic_signature(
init=cls.__init__, fields=fields, config_wrapper=config_wrapper, post_process_parameter=process_param_defaults
)


def is_builtin_dataclass(_cls: type[Any]) -> TypeGuard[type[StandardDataclass]]:
"""Returns True if a class is a stdlib dataclass and *not* a pydantic dataclass.

Expand Down
86 changes: 1 addition & 85 deletions pydantic/_internal/_generate_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@
CallbackGetCoreSchemaHandler,
)
from ._typing_extra import is_finalvar
from ._utils import is_valid_identifier, lenient_issubclass
from ._utils import lenient_issubclass

if TYPE_CHECKING:
from ..fields import ComputedFieldInfo, FieldInfo
Expand Down Expand Up @@ -2119,87 +2119,3 @@ def get(self) -> str | None:
return self._stack[-1]
else:
return None


def generate_pydantic_signature(
init: Callable[..., None],
fields: dict[str, FieldInfo],
config_wrapper: ConfigWrapper,
post_process_parameter: Callable[[Parameter], Parameter] = lambda x: x,
) -> inspect.Signature:
"""Generate signature for a pydantic class generated by inheriting from BaseModel or
using the dataclass annotation

Args:
init: The class init.
fields: The model fields.
config_wrapper: The config wrapper instance.
post_process_parameter: Optional additional processing for parameter

Returns:
The dataclass/BaseModel subclass signature.
"""
from itertools import islice

present_params = signature(init).parameters.values()
merged_params: dict[str, Parameter] = {}
var_kw = None
use_var_kw = False

for param in islice(present_params, 1, None): # skip self arg
# inspect does "clever" things to show annotations as strings because we have
# `from __future__ import annotations` in main, we don't want that
if param.annotation == 'Any':
param = param.replace(annotation=Any)
if param.kind is param.VAR_KEYWORD:
var_kw = param
continue
merged_params[param.name] = post_process_parameter(param)

if var_kw: # if custom init has no var_kw, fields which are not declared in it cannot be passed through
allow_names = config_wrapper.populate_by_name
for field_name, field in fields.items():
# when alias is a str it should be used for signature generation
if isinstance(field.alias, str):
param_name = field.alias
else:
param_name = field_name

if field_name in merged_params or param_name in merged_params:
continue

if not is_valid_identifier(param_name):
if allow_names and is_valid_identifier(field_name):
param_name = field_name
else:
use_var_kw = True
continue

kwargs = {} if field.is_required() else {'default': field.get_default(call_default_factory=False)}
merged_params[param_name] = post_process_parameter(
Parameter(param_name, Parameter.KEYWORD_ONLY, annotation=field.rebuild_annotation(), **kwargs)
)

if config_wrapper.extra == 'allow':
use_var_kw = True

if var_kw and use_var_kw:
# Make sure the parameter for extra kwargs
# does not have the same name as a field
default_model_signature = [
('self', Parameter.POSITIONAL_ONLY),
('data', Parameter.VAR_KEYWORD),
]
if [(p.name, p.kind) for p in present_params] == default_model_signature:
# if this is the standard model signature, use extra_data as the extra args name
var_kw_name = 'extra_data'
else:
# else start from var_kw
var_kw_name = var_kw.name

# generate a name that's definitely unique
while var_kw_name in fields:
var_kw_name += '_'
merged_params[var_kw_name] = post_process_parameter(var_kw.replace(name=var_kw_name))

return inspect.Signature(parameters=list(merged_params.values()), return_annotation=None)
24 changes: 4 additions & 20 deletions pydantic/_internal/_model_construction.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,10 @@
from ..plugin._schema_validator import create_schema_validator
from ..warnings import GenericBeforeBaseModelWarning, PydanticDeprecatedSince20
from ._config import ConfigWrapper
from ._constructor_signature_generators import generate_pydantic_signature
from ._decorators import DecoratorInfos, PydanticDescriptorProxy, get_attribute_from_bases
from ._fields import collect_model_fields, is_valid_field_name, is_valid_privateattr_name
from ._generate_schema import GenerateSchema, generate_pydantic_signature
from ._generate_schema import GenerateSchema
from ._generics import PydanticGenericMetadata, get_model_typevars_map
from ._mock_val_ser import MockValSer, set_model_mocks
from ._schema_generation_shared import CallbackGetCoreSchemaHandler
Expand All @@ -29,8 +30,6 @@
from ._validate_call import ValidateCallWrapper

if typing.TYPE_CHECKING:
from inspect import Signature

from ..fields import Field as PydanticModelField
from ..fields import FieldInfo, ModelPrivateAttr
from ..main import BaseModel
Expand Down Expand Up @@ -536,27 +535,12 @@ def complete_model_class(

# set __signature__ attr only for model class, but not for its instances
cls.__signature__ = ClassAttribute(
'__signature__', generate_model_signature(cls.__init__, cls.model_fields, config_wrapper)
'__signature__',
generate_pydantic_signature(init=cls.__init__, fields=cls.model_fields, config_wrapper=config_wrapper),
)
return True


def generate_model_signature(
init: Callable[..., None], fields: dict[str, FieldInfo], config_wrapper: ConfigWrapper
) -> Signature:
"""Generate signature for model based on its fields.

Args:
init: The class init.
fields: The model fields.
config_wrapper: The config wrapper instance.

Returns:
The model signature.
"""
return generate_pydantic_signature(init, fields, config_wrapper)


class _PydanticWeakRef:
"""Wrapper for `weakref.ref` that enables `pickle` serialization.

Expand Down