diff --git a/pydantic/_internal/_dataclasses.py b/pydantic/_internal/_dataclasses.py index d1d1c5fa0c..17240571e5 100644 --- a/pydantic/_internal/_dataclasses.py +++ b/pydantic/_internal/_dataclasses.py @@ -6,7 +6,7 @@ import typing import warnings from functools import partial, wraps -from inspect import Parameter, Signature, signature +from inspect import Parameter, Signature from typing import Any, Callable, ClassVar from pydantic_core import ( @@ -23,8 +23,9 @@ from ..plugin._schema_validator import create_schema_validator from ..warnings import PydanticDeprecatedSince20 from . import _config, _decorators, _typing_extra +from ._config import ConfigWrapper from ._fields import collect_dataclass_fields -from ._generate_schema import GenerateSchema +from ._generate_schema import GenerateSchema, generate_pydantic_signature from ._generics import get_standard_typevars_map from ._mock_val_ser import set_dataclass_mocks from ._schema_generation_shared import CallbackGetCoreSchemaHandler @@ -123,19 +124,20 @@ def complete_dataclass( typevars_map, ) - # dataclass.__init__ must be defined here so its `__qualname__` can be changed since functions can't be copied. + # This needs to be called before we change the __init__ + sig = generate_dataclass_signature(cls, cls.__pydantic_fields__, config_wrapper) # type: ignore + # dataclass.__init__ must be defined here so its `__qualname__` can be changed since functions can't be copied. def __init__(__dataclass_self__: PydanticDataclass, *args: Any, **kwargs: Any) -> None: __tracebackhide__ = True s = __dataclass_self__ s.__pydantic_validator__.validate_python(ArgsKwargs(args, kwargs), self_instance=s) __init__.__qualname__ = f'{cls.__qualname__}.__init__' - sig = generate_dataclass_signature(cls) + cls.__init__ = __init__ # type: ignore - cls.__signature__ = sig # type: ignore cls.__pydantic_config__ = config_wrapper.config_dict # type: ignore - + cls.__signature__ = sig # type: ignore get_core_schema = getattr(cls, '__get_pydantic_core_schema__', None) try: if get_core_schema: @@ -185,54 +187,61 @@ def validated_setattr(instance: Any, __field: str, __value: str) -> None: return True -def generate_dataclass_signature(cls: type[StandardDataclass]) -> Signature: - """Generate signature for a pydantic dataclass. +def process_param_defaults(param: Parameter) -> Parameter: + """Custom processing where the parameter default is of type FieldInfo - This implementation assumes we do not support custom `__init__`, which is currently true for pydantic dataclasses. - If we change this eventually, we should make this function's logic more closely mirror that from - `pydantic._internal._model_construction.generate_model_signature`. + Args: + param (Parameter): The parameter + + Returns: + Parameter: The custom processed parameter + """ + param_default = param.default + if isinstance(param_default, FieldInfo): + annotation = param.annotation + # Replace the annotation if appropriate + # inspect does "clever" things to show annotations as strings because we have + # `from __future__ import annotations` in main, we don't want that + if annotation == 'Any': + annotation = Any + + # Replace the field name with the alias if present + name = param.name + alias = param_default.alias + validation_alias = param_default.validation_alias + if validation_alias is None and isinstance(alias, str) and is_valid_identifier(alias): + name = alias + elif isinstance(validation_alias, str) and is_valid_identifier(validation_alias): + name = validation_alias + + # Replace the field default + default = param_default.default + if default is PydanticUndefined: + if param_default.default_factory is PydanticUndefined: + default = inspect.Signature.empty + else: + # this is used by dataclasses to indicate a factory exists: + default = dataclasses._HAS_DEFAULT_FACTORY # type: ignore + return param.replace(annotation=annotation, name=name, default=default) + return param + + +def generate_dataclass_signature( + cls: type[StandardDataclass], fields: dict[str, FieldInfo], config_wrapper: ConfigWrapper +) -> Signature: + """Generate signature for a pydantic dataclass. Args: cls: The dataclass. + fields: The model fields. + config_wrapper: The config wrapper instance. Returns: - The signature. + The dataclass signature. """ - sig = signature(cls) - final_params: dict[str, Parameter] = {} - - for param in sig.parameters.values(): - param_default = param.default - if isinstance(param_default, FieldInfo): - annotation = param.annotation - # Replace the annotation if appropriate - # inspect does "clever" things to show annotations as strings because we have - # `from __future__ import annotations` in main, we don't want that - if annotation == 'Any': - annotation = Any - - # Replace the field name with the alias if present - name = param.name - alias = param_default.alias - validation_alias = param_default.validation_alias - if validation_alias is None and isinstance(alias, str) and is_valid_identifier(alias): - name = alias - elif isinstance(validation_alias, str) and is_valid_identifier(validation_alias): - name = validation_alias - - # Replace the field default - default = param_default.default - if default is PydanticUndefined: - if param_default.default_factory is PydanticUndefined: - default = inspect.Signature.empty - else: - # this is used by dataclasses to indicate a factory exists: - default = dataclasses._HAS_DEFAULT_FACTORY # type: ignore - - param = param.replace(annotation=annotation, name=name, default=default) - final_params[param.name] = param - - return Signature(parameters=list(final_params.values()), return_annotation=None) + return generate_pydantic_signature( + init=cls.__init__, fields=fields, config_wrapper=config_wrapper, post_process_parameter=process_param_defaults + ) def is_builtin_dataclass(_cls: type[Any]) -> TypeGuard[type[StandardDataclass]]: diff --git a/pydantic/_internal/_generate_schema.py b/pydantic/_internal/_generate_schema.py index 15f16bc1eb..06e58eccff 100644 --- a/pydantic/_internal/_generate_schema.py +++ b/pydantic/_internal/_generate_schema.py @@ -82,7 +82,7 @@ CallbackGetCoreSchemaHandler, ) from ._typing_extra import is_finalvar -from ._utils import lenient_issubclass +from ._utils import is_valid_identifier, lenient_issubclass if TYPE_CHECKING: from ..main import BaseModel @@ -2085,3 +2085,87 @@ def get(self) -> str | None: return self._stack[-1] else: return None + + +def generate_pydantic_signature( + init: Callable[..., None], + fields: dict[str, FieldInfo], + config_wrapper: ConfigWrapper, + post_process_parameter: Callable[[Parameter], Parameter] = lambda x: x, +) -> inspect.Signature: + """Generate signature for a pydantic class generated by inheriting from BaseModel or + using the dataclass annotation + + Args: + init: The class init. + fields: The model fields. + config_wrapper: The config wrapper instance. + post_process_parameter: Optional additional processing for parameter + + Returns: + The dataclass/BaseModel subclass signature. + """ + from itertools import islice + + present_params = signature(init).parameters.values() + merged_params: dict[str, Parameter] = {} + var_kw = None + use_var_kw = False + + for param in islice(present_params, 1, None): # skip self arg + # inspect does "clever" things to show annotations as strings because we have + # `from __future__ import annotations` in main, we don't want that + if param.annotation == 'Any': + param = param.replace(annotation=Any) + if param.kind is param.VAR_KEYWORD: + var_kw = param + continue + merged_params[param.name] = post_process_parameter(param) + + if var_kw: # if custom init has no var_kw, fields which are not declared in it cannot be passed through + allow_names = config_wrapper.populate_by_name + for field_name, field in fields.items(): + # when alias is a str it should be used for signature generation + if isinstance(field.alias, str): + param_name = field.alias + else: + param_name = field_name + + if field_name in merged_params or param_name in merged_params: + continue + + if not is_valid_identifier(param_name): + if allow_names and is_valid_identifier(field_name): + param_name = field_name + else: + use_var_kw = True + continue + + kwargs = {} if field.is_required() else {'default': field.get_default(call_default_factory=False)} + merged_params[param_name] = post_process_parameter( + Parameter(param_name, Parameter.KEYWORD_ONLY, annotation=field.rebuild_annotation(), **kwargs) + ) + + if config_wrapper.extra == 'allow': + use_var_kw = True + + if var_kw and use_var_kw: + # Make sure the parameter for extra kwargs + # does not have the same name as a field + default_model_signature = [ + ('__pydantic_self__', Parameter.POSITIONAL_OR_KEYWORD), + ('data', Parameter.VAR_KEYWORD), + ] + if [(p.name, p.kind) for p in present_params] == default_model_signature: + # if this is the standard model signature, use extra_data as the extra args name + var_kw_name = 'extra_data' + else: + # else start from var_kw + var_kw_name = var_kw.name + + # generate a name that's definitely unique + while var_kw_name in fields: + var_kw_name += '_' + merged_params[var_kw_name] = post_process_parameter(var_kw.replace(name=var_kw_name)) + + return inspect.Signature(parameters=list(merged_params.values()), return_annotation=None) diff --git a/pydantic/_internal/_model_construction.py b/pydantic/_internal/_model_construction.py index 9cf88e3abb..d5d0225c8d 100644 --- a/pydantic/_internal/_model_construction.py +++ b/pydantic/_internal/_model_construction.py @@ -24,12 +24,12 @@ get_attribute_from_bases, ) from ._fields import collect_model_fields, is_valid_field_name, is_valid_privateattr_name -from ._generate_schema import GenerateSchema +from ._generate_schema import GenerateSchema, generate_pydantic_signature from ._generics import PydanticGenericMetadata, get_model_typevars_map from ._mock_val_ser import MockValSer, set_model_mocks from ._schema_generation_shared import CallbackGetCoreSchemaHandler from ._typing_extra import get_cls_types_namespace, is_classvar, parent_frame_namespace -from ._utils import ClassAttribute, is_valid_identifier +from ._utils import ClassAttribute from ._validate_call import ValidateCallWrapper if typing.TYPE_CHECKING: @@ -518,71 +518,7 @@ def generate_model_signature( Returns: The model signature. """ - from inspect import Parameter, Signature, signature - from itertools import islice - - present_params = signature(init).parameters.values() - merged_params: dict[str, Parameter] = {} - var_kw = None - use_var_kw = False - - for param in islice(present_params, 1, None): # skip self arg - # inspect does "clever" things to show annotations as strings because we have - # `from __future__ import annotations` in main, we don't want that - if param.annotation == 'Any': - param = param.replace(annotation=Any) - if param.kind is param.VAR_KEYWORD: - var_kw = param - continue - merged_params[param.name] = param - - if var_kw: # if custom init has no var_kw, fields which are not declared in it cannot be passed through - allow_names = config_wrapper.populate_by_name - for field_name, field in fields.items(): - # when alias is a str it should be used for signature generation - if isinstance(field.alias, str): - param_name = field.alias - else: - param_name = field_name - - if field_name in merged_params or param_name in merged_params: - continue - - if not is_valid_identifier(param_name): - if allow_names and is_valid_identifier(field_name): - param_name = field_name - else: - use_var_kw = True - continue - - kwargs = {} if field.is_required() else {'default': field.get_default(call_default_factory=False)} - merged_params[param_name] = Parameter( - param_name, Parameter.KEYWORD_ONLY, annotation=field.rebuild_annotation(), **kwargs - ) - - if config_wrapper.extra == 'allow': - use_var_kw = True - - if var_kw and use_var_kw: - # Make sure the parameter for extra kwargs - # does not have the same name as a field - default_model_signature = [ - ('__pydantic_self__', Parameter.POSITIONAL_OR_KEYWORD), - ('data', Parameter.VAR_KEYWORD), - ] - if [(p.name, p.kind) for p in present_params] == default_model_signature: - # if this is the standard model signature, use extra_data as the extra args name - var_kw_name = 'extra_data' - else: - # else start from var_kw - var_kw_name = var_kw.name - - # generate a name that's definitely unique - while var_kw_name in fields: - var_kw_name += '_' - merged_params[var_kw_name] = var_kw.replace(name=var_kw_name) - - return Signature(parameters=list(merged_params.values()), return_annotation=None) + return generate_pydantic_signature(init, fields, config_wrapper) class _PydanticWeakRef: diff --git a/tests/test_dataclasses.py b/tests/test_dataclasses.py index 0f3d1d6dd5..7b363ad7e1 100644 --- a/tests/test_dataclasses.py +++ b/tests/test_dataclasses.py @@ -2507,6 +2507,19 @@ class Model: ) +def test_inherited_dataclass_signature(): + @pydantic.dataclasses.dataclass + class A: + a: int + + @pydantic.dataclasses.dataclass + class B(A): + b: int + + assert str(inspect.signature(A)) == '(a: int) -> None' + assert str(inspect.signature(B)) == '(a: int, b: int) -> None' + + def test_dataclasses_with_slots_and_default(): @pydantic.dataclasses.dataclass(slots=True) class A: