Skip to content

Commit

Permalink
Minor refactoring -- no behavioral changes (#8236)
Browse files Browse the repository at this point in the history
  • Loading branch information
sydney-runkle committed Nov 27, 2023
1 parent 22e6444 commit 202d379
Show file tree
Hide file tree
Showing 6 changed files with 155 additions and 174 deletions.
93 changes: 93 additions & 0 deletions pydantic/_internal/_constructor_signature_generators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
from __future__ import annotations

from inspect import Parameter, Signature, signature
from typing import TYPE_CHECKING, Any, Callable

from ._config import ConfigWrapper
from ._utils import is_valid_identifier

if TYPE_CHECKING:
from ..fields import FieldInfo


def generate_pydantic_signature(
init: Callable[..., None],
fields: dict[str, FieldInfo],
config_wrapper: ConfigWrapper,
parameter_post_processor: Callable[[Parameter], Parameter] = lambda x: x,
) -> Signature:
"""Generate signature for a pydantic BaseModel or dataclass.
Args:
init: The class init.
fields: The model fields.
config_wrapper: The config wrapper instance.
parameter_post_processor: Optional additional processing for parameter
Returns:
The dataclass/BaseModel subclass signature.
"""
from itertools import islice

present_params = signature(init).parameters.values()
merged_params: dict[str, Parameter] = {}
var_kw = None
use_var_kw = False

for param in islice(present_params, 1, None): # skip self arg
# inspect does "clever" things to show annotations as strings because we have
# `from __future__ import annotations` in main, we don't want that
if param.annotation == 'Any':
param = param.replace(annotation=Any)
if param.kind is param.VAR_KEYWORD:
var_kw = param
continue
merged_params[param.name] = parameter_post_processor(param)

if var_kw: # if custom init has no var_kw, fields which are not declared in it cannot be passed through
allow_names = config_wrapper.populate_by_name
for field_name, field in fields.items():
# when alias is a str it should be used for signature generation
if isinstance(field.alias, str):
param_name = field.alias
else:
param_name = field_name

if field_name in merged_params or param_name in merged_params:
continue

if not is_valid_identifier(param_name):
if allow_names and is_valid_identifier(field_name):
param_name = field_name
else:
use_var_kw = True
continue

kwargs = {} if field.is_required() else {'default': field.get_default(call_default_factory=False)}
merged_params[param_name] = parameter_post_processor(
Parameter(param_name, Parameter.KEYWORD_ONLY, annotation=field.rebuild_annotation(), **kwargs)
)

if config_wrapper.extra == 'allow':
use_var_kw = True

if var_kw and use_var_kw:
# Make sure the parameter for extra kwargs
# does not have the same name as a field
default_model_signature = [
('self', Parameter.POSITIONAL_ONLY),
('data', Parameter.VAR_KEYWORD),
]
if [(p.name, p.kind) for p in present_params] == default_model_signature:
# if this is the standard model signature, use extra_data as the extra args name
var_kw_name = 'extra_data'
else:
# else start from var_kw
var_kw_name = var_kw.name

# generate a name that's definitely unique
while var_kw_name in fields:
var_kw_name += '_'
merged_params[var_kw_name] = parameter_post_processor(var_kw.replace(name=var_kw_name))

return Signature(parameters=list(merged_params.values()), return_annotation=None)
33 changes: 10 additions & 23 deletions pydantic/_internal/_dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import typing
import warnings
from functools import partial, wraps
from inspect import Parameter, Signature
from inspect import Parameter
from typing import Any, Callable, ClassVar

from pydantic_core import (
Expand All @@ -23,9 +23,9 @@
from ..plugin._schema_validator import create_schema_validator
from ..warnings import PydanticDeprecatedSince20
from . import _config, _decorators, _typing_extra
from ._config import ConfigWrapper
from ._constructor_signature_generators import generate_pydantic_signature
from ._fields import collect_dataclass_fields
from ._generate_schema import GenerateSchema, generate_pydantic_signature
from ._generate_schema import GenerateSchema
from ._generics import get_standard_typevars_map
from ._mock_val_ser import set_dataclass_mocks
from ._schema_generation_shared import CallbackGetCoreSchemaHandler
Expand Down Expand Up @@ -125,7 +125,12 @@ def complete_dataclass(
)

# This needs to be called before we change the __init__
sig = generate_dataclass_signature(cls, cls.__pydantic_fields__, config_wrapper) # type: ignore
sig = generate_pydantic_signature(
init=cls.__init__,
fields=cls.__pydantic_fields__, # type: ignore
config_wrapper=config_wrapper,
parameter_post_processor=process_param_defaults,
)

# dataclass.__init__ must be defined here so its `__qualname__` can be changed since functions can't be copied.
def __init__(__dataclass_self__: PydanticDataclass, *args: Any, **kwargs: Any) -> None:
Expand Down Expand Up @@ -188,7 +193,7 @@ def validated_setattr(instance: Any, __field: str, __value: str) -> None:


def process_param_defaults(param: Parameter) -> Parameter:
"""Custom processing where the parameter default is of type FieldInfo
"""Modify the signature for a parameter in a dataclass where the default value is a FieldInfo instance.
Args:
param (Parameter): The parameter
Expand Down Expand Up @@ -226,24 +231,6 @@ def process_param_defaults(param: Parameter) -> Parameter:
return param


def generate_dataclass_signature(
cls: type[StandardDataclass], fields: dict[str, FieldInfo], config_wrapper: ConfigWrapper
) -> Signature:
"""Generate signature for a pydantic dataclass.
Args:
cls: The dataclass.
fields: The model fields.
config_wrapper: The config wrapper instance.
Returns:
The dataclass signature.
"""
return generate_pydantic_signature(
init=cls.__init__, fields=fields, config_wrapper=config_wrapper, post_process_parameter=process_param_defaults
)


def is_builtin_dataclass(_cls: type[Any]) -> TypeGuard[type[StandardDataclass]]:
"""Returns True if a class is a stdlib dataclass and *not* a pydantic dataclass.
Expand Down
86 changes: 1 addition & 85 deletions pydantic/_internal/_generate_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@
CallbackGetCoreSchemaHandler,
)
from ._typing_extra import is_finalvar
from ._utils import is_valid_identifier, lenient_issubclass
from ._utils import lenient_issubclass

if TYPE_CHECKING:
from ..fields import ComputedFieldInfo, FieldInfo
Expand Down Expand Up @@ -2119,87 +2119,3 @@ def get(self) -> str | None:
return self._stack[-1]
else:
return None


def generate_pydantic_signature(
init: Callable[..., None],
fields: dict[str, FieldInfo],
config_wrapper: ConfigWrapper,
post_process_parameter: Callable[[Parameter], Parameter] = lambda x: x,
) -> inspect.Signature:
"""Generate signature for a pydantic class generated by inheriting from BaseModel or
using the dataclass annotation
Args:
init: The class init.
fields: The model fields.
config_wrapper: The config wrapper instance.
post_process_parameter: Optional additional processing for parameter
Returns:
The dataclass/BaseModel subclass signature.
"""
from itertools import islice

present_params = signature(init).parameters.values()
merged_params: dict[str, Parameter] = {}
var_kw = None
use_var_kw = False

for param in islice(present_params, 1, None): # skip self arg
# inspect does "clever" things to show annotations as strings because we have
# `from __future__ import annotations` in main, we don't want that
if param.annotation == 'Any':
param = param.replace(annotation=Any)
if param.kind is param.VAR_KEYWORD:
var_kw = param
continue
merged_params[param.name] = post_process_parameter(param)

if var_kw: # if custom init has no var_kw, fields which are not declared in it cannot be passed through
allow_names = config_wrapper.populate_by_name
for field_name, field in fields.items():
# when alias is a str it should be used for signature generation
if isinstance(field.alias, str):
param_name = field.alias
else:
param_name = field_name

if field_name in merged_params or param_name in merged_params:
continue

if not is_valid_identifier(param_name):
if allow_names and is_valid_identifier(field_name):
param_name = field_name
else:
use_var_kw = True
continue

kwargs = {} if field.is_required() else {'default': field.get_default(call_default_factory=False)}
merged_params[param_name] = post_process_parameter(
Parameter(param_name, Parameter.KEYWORD_ONLY, annotation=field.rebuild_annotation(), **kwargs)
)

if config_wrapper.extra == 'allow':
use_var_kw = True

if var_kw and use_var_kw:
# Make sure the parameter for extra kwargs
# does not have the same name as a field
default_model_signature = [
('self', Parameter.POSITIONAL_ONLY),
('data', Parameter.VAR_KEYWORD),
]
if [(p.name, p.kind) for p in present_params] == default_model_signature:
# if this is the standard model signature, use extra_data as the extra args name
var_kw_name = 'extra_data'
else:
# else start from var_kw
var_kw_name = var_kw.name

# generate a name that's definitely unique
while var_kw_name in fields:
var_kw_name += '_'
merged_params[var_kw_name] = post_process_parameter(var_kw.replace(name=var_kw_name))

return inspect.Signature(parameters=list(merged_params.values()), return_annotation=None)
24 changes: 4 additions & 20 deletions pydantic/_internal/_model_construction.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,10 @@
from ..plugin._schema_validator import create_schema_validator
from ..warnings import GenericBeforeBaseModelWarning, PydanticDeprecatedSince20
from ._config import ConfigWrapper
from ._constructor_signature_generators import generate_pydantic_signature
from ._decorators import DecoratorInfos, PydanticDescriptorProxy, get_attribute_from_bases
from ._fields import collect_model_fields, is_valid_field_name, is_valid_privateattr_name
from ._generate_schema import GenerateSchema, generate_pydantic_signature
from ._generate_schema import GenerateSchema
from ._generics import PydanticGenericMetadata, get_model_typevars_map
from ._mock_val_ser import MockValSer, set_model_mocks
from ._schema_generation_shared import CallbackGetCoreSchemaHandler
Expand All @@ -29,8 +30,6 @@
from ._validate_call import ValidateCallWrapper

if typing.TYPE_CHECKING:
from inspect import Signature

from ..fields import Field as PydanticModelField
from ..fields import FieldInfo, ModelPrivateAttr
from ..main import BaseModel
Expand Down Expand Up @@ -536,27 +535,12 @@ def complete_model_class(

# set __signature__ attr only for model class, but not for its instances
cls.__signature__ = ClassAttribute(
'__signature__', generate_model_signature(cls.__init__, cls.model_fields, config_wrapper)
'__signature__',
generate_pydantic_signature(init=cls.__init__, fields=cls.model_fields, config_wrapper=config_wrapper),
)
return True


def generate_model_signature(
init: Callable[..., None], fields: dict[str, FieldInfo], config_wrapper: ConfigWrapper
) -> Signature:
"""Generate signature for model based on its fields.
Args:
init: The class init.
fields: The model fields.
config_wrapper: The config wrapper instance.
Returns:
The model signature.
"""
return generate_pydantic_signature(init, fields, config_wrapper)


class _PydanticWeakRef:
"""Wrapper for `weakref.ref` that enables `pickle` serialization.
Expand Down

0 comments on commit 202d379

Please sign in to comment.