diff --git a/pydantic/_internal/_constructor_signature_generators.py b/pydantic/_internal/_constructor_signature_generators.py new file mode 100644 index 0000000000..7ae28b0231 --- /dev/null +++ b/pydantic/_internal/_constructor_signature_generators.py @@ -0,0 +1,93 @@ +from __future__ import annotations + +from inspect import Parameter, Signature, signature +from typing import TYPE_CHECKING, Any, Callable + +from ._config import ConfigWrapper +from ._utils import is_valid_identifier + +if TYPE_CHECKING: + from ..fields import FieldInfo + + +def generate_pydantic_signature( + init: Callable[..., None], + fields: dict[str, FieldInfo], + config_wrapper: ConfigWrapper, + parameter_post_processor: Callable[[Parameter], Parameter] = lambda x: x, +) -> Signature: + """Generate signature for a pydantic BaseModel or dataclass. + + Args: + init: The class init. + fields: The model fields. + config_wrapper: The config wrapper instance. + parameter_post_processor: Optional additional processing for parameter + + Returns: + The dataclass/BaseModel subclass signature. + """ + from itertools import islice + + present_params = signature(init).parameters.values() + merged_params: dict[str, Parameter] = {} + var_kw = None + use_var_kw = False + + for param in islice(present_params, 1, None): # skip self arg + # inspect does "clever" things to show annotations as strings because we have + # `from __future__ import annotations` in main, we don't want that + if param.annotation == 'Any': + param = param.replace(annotation=Any) + if param.kind is param.VAR_KEYWORD: + var_kw = param + continue + merged_params[param.name] = parameter_post_processor(param) + + if var_kw: # if custom init has no var_kw, fields which are not declared in it cannot be passed through + allow_names = config_wrapper.populate_by_name + for field_name, field in fields.items(): + # when alias is a str it should be used for signature generation + if isinstance(field.alias, str): + param_name = field.alias + else: + param_name = field_name + + if field_name in merged_params or param_name in merged_params: + continue + + if not is_valid_identifier(param_name): + if allow_names and is_valid_identifier(field_name): + param_name = field_name + else: + use_var_kw = True + continue + + kwargs = {} if field.is_required() else {'default': field.get_default(call_default_factory=False)} + merged_params[param_name] = parameter_post_processor( + Parameter(param_name, Parameter.KEYWORD_ONLY, annotation=field.rebuild_annotation(), **kwargs) + ) + + if config_wrapper.extra == 'allow': + use_var_kw = True + + if var_kw and use_var_kw: + # Make sure the parameter for extra kwargs + # does not have the same name as a field + default_model_signature = [ + ('self', Parameter.POSITIONAL_ONLY), + ('data', Parameter.VAR_KEYWORD), + ] + if [(p.name, p.kind) for p in present_params] == default_model_signature: + # if this is the standard model signature, use extra_data as the extra args name + var_kw_name = 'extra_data' + else: + # else start from var_kw + var_kw_name = var_kw.name + + # generate a name that's definitely unique + while var_kw_name in fields: + var_kw_name += '_' + merged_params[var_kw_name] = parameter_post_processor(var_kw.replace(name=var_kw_name)) + + return Signature(parameters=list(merged_params.values()), return_annotation=None) diff --git a/pydantic/_internal/_dataclasses.py b/pydantic/_internal/_dataclasses.py index 2bc43e9665..065b004d81 100644 --- a/pydantic/_internal/_dataclasses.py +++ b/pydantic/_internal/_dataclasses.py @@ -6,7 +6,7 @@ import typing import warnings from functools import partial, wraps -from inspect import Parameter, Signature +from inspect import Parameter from typing import Any, Callable, ClassVar from pydantic_core import ( @@ -23,9 +23,9 @@ from ..plugin._schema_validator import create_schema_validator from ..warnings import PydanticDeprecatedSince20 from . import _config, _decorators, _typing_extra -from ._config import ConfigWrapper +from ._constructor_signature_generators import generate_pydantic_signature from ._fields import collect_dataclass_fields -from ._generate_schema import GenerateSchema, generate_pydantic_signature +from ._generate_schema import GenerateSchema from ._generics import get_standard_typevars_map from ._mock_val_ser import set_dataclass_mocks from ._schema_generation_shared import CallbackGetCoreSchemaHandler @@ -125,7 +125,12 @@ def complete_dataclass( ) # This needs to be called before we change the __init__ - sig = generate_dataclass_signature(cls, cls.__pydantic_fields__, config_wrapper) # type: ignore + sig = generate_pydantic_signature( + init=cls.__init__, + fields=cls.__pydantic_fields__, # type: ignore + config_wrapper=config_wrapper, + parameter_post_processor=process_param_defaults, + ) # dataclass.__init__ must be defined here so its `__qualname__` can be changed since functions can't be copied. def __init__(__dataclass_self__: PydanticDataclass, *args: Any, **kwargs: Any) -> None: @@ -188,7 +193,7 @@ def validated_setattr(instance: Any, __field: str, __value: str) -> None: def process_param_defaults(param: Parameter) -> Parameter: - """Custom processing where the parameter default is of type FieldInfo + """Modify the signature for a parameter in a dataclass where the default value is a FieldInfo instance. Args: param (Parameter): The parameter @@ -226,24 +231,6 @@ def process_param_defaults(param: Parameter) -> Parameter: return param -def generate_dataclass_signature( - cls: type[StandardDataclass], fields: dict[str, FieldInfo], config_wrapper: ConfigWrapper -) -> Signature: - """Generate signature for a pydantic dataclass. - - Args: - cls: The dataclass. - fields: The model fields. - config_wrapper: The config wrapper instance. - - Returns: - The dataclass signature. - """ - return generate_pydantic_signature( - init=cls.__init__, fields=fields, config_wrapper=config_wrapper, post_process_parameter=process_param_defaults - ) - - def is_builtin_dataclass(_cls: type[Any]) -> TypeGuard[type[StandardDataclass]]: """Returns True if a class is a stdlib dataclass and *not* a pydantic dataclass. diff --git a/pydantic/_internal/_generate_schema.py b/pydantic/_internal/_generate_schema.py index 84f8a637db..1b673f911e 100644 --- a/pydantic/_internal/_generate_schema.py +++ b/pydantic/_internal/_generate_schema.py @@ -78,7 +78,7 @@ CallbackGetCoreSchemaHandler, ) from ._typing_extra import is_finalvar -from ._utils import is_valid_identifier, lenient_issubclass +from ._utils import lenient_issubclass if TYPE_CHECKING: from ..fields import ComputedFieldInfo, FieldInfo @@ -2119,87 +2119,3 @@ def get(self) -> str | None: return self._stack[-1] else: return None - - -def generate_pydantic_signature( - init: Callable[..., None], - fields: dict[str, FieldInfo], - config_wrapper: ConfigWrapper, - post_process_parameter: Callable[[Parameter], Parameter] = lambda x: x, -) -> inspect.Signature: - """Generate signature for a pydantic class generated by inheriting from BaseModel or - using the dataclass annotation - - Args: - init: The class init. - fields: The model fields. - config_wrapper: The config wrapper instance. - post_process_parameter: Optional additional processing for parameter - - Returns: - The dataclass/BaseModel subclass signature. - """ - from itertools import islice - - present_params = signature(init).parameters.values() - merged_params: dict[str, Parameter] = {} - var_kw = None - use_var_kw = False - - for param in islice(present_params, 1, None): # skip self arg - # inspect does "clever" things to show annotations as strings because we have - # `from __future__ import annotations` in main, we don't want that - if param.annotation == 'Any': - param = param.replace(annotation=Any) - if param.kind is param.VAR_KEYWORD: - var_kw = param - continue - merged_params[param.name] = post_process_parameter(param) - - if var_kw: # if custom init has no var_kw, fields which are not declared in it cannot be passed through - allow_names = config_wrapper.populate_by_name - for field_name, field in fields.items(): - # when alias is a str it should be used for signature generation - if isinstance(field.alias, str): - param_name = field.alias - else: - param_name = field_name - - if field_name in merged_params or param_name in merged_params: - continue - - if not is_valid_identifier(param_name): - if allow_names and is_valid_identifier(field_name): - param_name = field_name - else: - use_var_kw = True - continue - - kwargs = {} if field.is_required() else {'default': field.get_default(call_default_factory=False)} - merged_params[param_name] = post_process_parameter( - Parameter(param_name, Parameter.KEYWORD_ONLY, annotation=field.rebuild_annotation(), **kwargs) - ) - - if config_wrapper.extra == 'allow': - use_var_kw = True - - if var_kw and use_var_kw: - # Make sure the parameter for extra kwargs - # does not have the same name as a field - default_model_signature = [ - ('self', Parameter.POSITIONAL_ONLY), - ('data', Parameter.VAR_KEYWORD), - ] - if [(p.name, p.kind) for p in present_params] == default_model_signature: - # if this is the standard model signature, use extra_data as the extra args name - var_kw_name = 'extra_data' - else: - # else start from var_kw - var_kw_name = var_kw.name - - # generate a name that's definitely unique - while var_kw_name in fields: - var_kw_name += '_' - merged_params[var_kw_name] = post_process_parameter(var_kw.replace(name=var_kw_name)) - - return inspect.Signature(parameters=list(merged_params.values()), return_annotation=None) diff --git a/pydantic/_internal/_model_construction.py b/pydantic/_internal/_model_construction.py index c3f2a8313e..1d7fbf788a 100644 --- a/pydantic/_internal/_model_construction.py +++ b/pydantic/_internal/_model_construction.py @@ -18,9 +18,10 @@ from ..plugin._schema_validator import create_schema_validator from ..warnings import GenericBeforeBaseModelWarning, PydanticDeprecatedSince20 from ._config import ConfigWrapper +from ._constructor_signature_generators import generate_pydantic_signature from ._decorators import DecoratorInfos, PydanticDescriptorProxy, get_attribute_from_bases from ._fields import collect_model_fields, is_valid_field_name, is_valid_privateattr_name -from ._generate_schema import GenerateSchema, generate_pydantic_signature +from ._generate_schema import GenerateSchema from ._generics import PydanticGenericMetadata, get_model_typevars_map from ._mock_val_ser import MockValSer, set_model_mocks from ._schema_generation_shared import CallbackGetCoreSchemaHandler @@ -29,8 +30,6 @@ from ._validate_call import ValidateCallWrapper if typing.TYPE_CHECKING: - from inspect import Signature - from ..fields import Field as PydanticModelField from ..fields import FieldInfo, ModelPrivateAttr from ..main import BaseModel @@ -536,27 +535,12 @@ def complete_model_class( # set __signature__ attr only for model class, but not for its instances cls.__signature__ = ClassAttribute( - '__signature__', generate_model_signature(cls.__init__, cls.model_fields, config_wrapper) + '__signature__', + generate_pydantic_signature(init=cls.__init__, fields=cls.model_fields, config_wrapper=config_wrapper), ) return True -def generate_model_signature( - init: Callable[..., None], fields: dict[str, FieldInfo], config_wrapper: ConfigWrapper -) -> Signature: - """Generate signature for model based on its fields. - - Args: - init: The class init. - fields: The model fields. - config_wrapper: The config wrapper instance. - - Returns: - The model signature. - """ - return generate_pydantic_signature(init, fields, config_wrapper) - - class _PydanticWeakRef: """Wrapper for `weakref.ref` that enables `pickle` serialization. diff --git a/pydantic/fields.py b/pydantic/fields.py index cf9cad19a9..303c82426a 100644 --- a/pydantic/fields.py +++ b/pydantic/fields.py @@ -207,10 +207,8 @@ def __init__(self, **kwargs: Unpack[_FieldInfoInputs]) -> None: self.metadata = self._collect_metadata(kwargs) + annotation_metadata # type: ignore - @classmethod - def from_field( - cls, default: Any = PydanticUndefined, **kwargs: Unpack[_FromFieldInfoInputs] - ) -> typing_extensions.Self: + @staticmethod + def from_field(default: Any = PydanticUndefined, **kwargs: Unpack[_FromFieldInfoInputs]) -> FieldInfo: """Create a new `FieldInfo` object with the `Field` function. Args: @@ -235,10 +233,10 @@ class MyModel(pydantic.BaseModel): """ if 'annotation' in kwargs: raise TypeError('"annotation" is not permitted as a Field keyword argument') - return cls(default=default, **kwargs) + return FieldInfo(default=default, **kwargs) - @classmethod - def from_annotation(cls, annotation: type[Any]) -> FieldInfo: + @staticmethod + def from_annotation(annotation: type[Any]) -> FieldInfo: """Creates a `FieldInfo` instance from a bare annotation. Args: @@ -283,7 +281,7 @@ class MyModel(pydantic.BaseModel): if _typing_extra.is_finalvar(first_arg): final = True field_info_annotations = [a for a in extra_args if isinstance(a, FieldInfo)] - field_info = cls.merge_field_infos(*field_info_annotations, annotation=first_arg) + field_info = FieldInfo.merge_field_infos(*field_info_annotations, annotation=first_arg) if field_info: new_field_info = copy(field_info) new_field_info.annotation = first_arg @@ -297,10 +295,10 @@ class MyModel(pydantic.BaseModel): new_field_info.metadata = metadata return new_field_info - return cls(annotation=annotation, frozen=final or None) + return FieldInfo(annotation=annotation, frozen=final or None) - @classmethod - def from_annotated_attribute(cls, annotation: type[Any], default: Any) -> FieldInfo: + @staticmethod + def from_annotated_attribute(annotation: type[Any], default: Any) -> FieldInfo: """Create `FieldInfo` from an annotation with a default value. Args: @@ -329,11 +327,11 @@ class MyModel(pydantic.BaseModel): if annotation is not typing_extensions.Final: annotation = typing_extensions.get_args(annotation)[0] - if isinstance(default, cls): - default.annotation, annotation_metadata = cls._extract_metadata(annotation) + if isinstance(default, FieldInfo): + default.annotation, annotation_metadata = FieldInfo._extract_metadata(annotation) default.metadata += annotation_metadata default = default.merge_field_infos( - *[x for x in annotation_metadata if isinstance(x, cls)], default, annotation=default.annotation + *[x for x in annotation_metadata if isinstance(x, FieldInfo)], default, annotation=default.annotation ) default.frozen = final or default.frozen return default @@ -345,11 +343,11 @@ class MyModel(pydantic.BaseModel): elif isinstance(annotation, dataclasses.InitVar): init_var = True annotation = annotation.type - pydantic_field = cls._from_dataclass_field(default) - pydantic_field.annotation, annotation_metadata = cls._extract_metadata(annotation) + pydantic_field = FieldInfo._from_dataclass_field(default) + pydantic_field.annotation, annotation_metadata = FieldInfo._extract_metadata(annotation) pydantic_field.metadata += annotation_metadata pydantic_field = pydantic_field.merge_field_infos( - *[x for x in annotation_metadata if isinstance(x, cls)], + *[x for x in annotation_metadata if isinstance(x, FieldInfo)], pydantic_field, annotation=pydantic_field.annotation, ) @@ -361,7 +359,7 @@ class MyModel(pydantic.BaseModel): if _typing_extra.is_annotated(annotation): first_arg, *extra_args = typing_extensions.get_args(annotation) field_infos = [a for a in extra_args if isinstance(a, FieldInfo)] - field_info = cls.merge_field_infos(*field_infos, annotation=first_arg, default=default) + field_info = FieldInfo.merge_field_infos(*field_infos, annotation=first_arg, default=default) metadata: list[Any] = [] for a in extra_args: if not isinstance(a, FieldInfo): @@ -371,7 +369,7 @@ class MyModel(pydantic.BaseModel): field_info.metadata = metadata return field_info - return cls(annotation=annotation, default=default, frozen=final or None) + return FieldInfo(annotation=annotation, default=default, frozen=final or None) @staticmethod def merge_field_infos(*field_infos: FieldInfo, **overrides: Any) -> FieldInfo: @@ -407,8 +405,8 @@ def merge_field_infos(*field_infos: FieldInfo, **overrides: Any) -> FieldInfo: field_info.metadata = list(metadata.values()) return field_info - @classmethod - def _from_dataclass_field(cls, dc_field: DataclassField[Any]) -> typing_extensions.Self: + @staticmethod + def _from_dataclass_field(dc_field: DataclassField[Any]) -> FieldInfo: """Return a new `FieldInfo` instance from a `dataclasses.Field` instance. Args: @@ -433,8 +431,8 @@ def _from_dataclass_field(cls, dc_field: DataclassField[Any]) -> typing_extensio dc_field_metadata = {k: v for k, v in dc_field.metadata.items() if k in _FIELD_ARG_NAMES} return Field(default=default, default_factory=default_factory, repr=dc_field.repr, **dc_field_metadata) - @classmethod - def _extract_metadata(cls, annotation: type[Any] | None) -> tuple[type[Any] | None, list[Any]]: + @staticmethod + def _extract_metadata(annotation: type[Any] | None) -> tuple[type[Any] | None, list[Any]]: """Tries to extract metadata/constraints from an annotation if it uses `Annotated`. Args: @@ -450,8 +448,8 @@ def _extract_metadata(cls, annotation: type[Any] | None) -> tuple[type[Any] | No return annotation, [] - @classmethod - def _collect_metadata(cls, kwargs: dict[str, Any]) -> list[Any]: + @staticmethod + def _collect_metadata(kwargs: dict[str, Any]) -> list[Any]: """Collect annotations from kwargs. The return type is actually `annotated_types.BaseMetadata | PydanticMetadata`, @@ -468,7 +466,7 @@ def _collect_metadata(cls, kwargs: dict[str, Any]) -> list[Any]: general_metadata = {} for key, value in list(kwargs.items()): try: - marker = cls.metadata_lookup[key] + marker = FieldInfo.metadata_lookup[key] except KeyError: continue diff --git a/pydantic/mypy.py b/pydantic/mypy.py index eb2127e253..b5a869236b 100644 --- a/pydantic/mypy.py +++ b/pydantic/mypy.py @@ -401,10 +401,12 @@ def expand_typevar_from_subtype(self, sub_type: TypeInfo) -> None: class PydanticModelClassVar: - """Class vars are stored to be ignored by subclasses. + """Based on mypy.plugins.dataclasses.DataclassAttribute. + + ClassVars are ignored by subclasses. Attributes: - name: the class var name + name: the ClassVar name """ def __init__(self, name): @@ -464,8 +466,8 @@ def transform(self) -> bool: info = self._cls.info is_root_model = any(ROOT_MODEL_FULLNAME in base.fullname for base in info.mro[:-1]) config = self.collect_config() - fields, classvars = self.collect_fields_and_classvar(config, is_root_model) - if fields is None or classvars is None: + fields, class_vars = self.collect_fields_and_class_vars(config, is_root_model) + if fields is None or class_vars is None: # Some definitions are not ready. We need another pass. return False for field in fields: @@ -485,7 +487,7 @@ def transform(self) -> bool: info.metadata[METADATA_KEY] = { 'fields': {field.name: field.serialize() for field in fields}, - 'classvars': {classvar.name: classvar.serialize() for classvar in classvars}, + 'class_vars': {class_var.name: class_var.serialize() for class_var in class_vars}, 'config': config.get_values_dict(), } @@ -594,13 +596,13 @@ def collect_config(self) -> ModelConfigData: # noqa: C901 (ignore complexity) config.setdefault(name, value) return config - def collect_fields_and_classvar( + def collect_fields_and_class_vars( self, model_config: ModelConfigData, is_root_model: bool ) -> tuple[list[PydanticModelField] | None, list[PydanticModelClassVar] | None]: """Collects the fields for the model, accounting for parent classes.""" cls = self._cls - # First, collect fields and classvars belonging to any class in the MRO, ignoring duplicates. + # First, collect fields and ClassVars belonging to any class in the MRO, ignoring duplicates. # # We iterate through the MRO in reverse because attrs defined in the parent must appear # earlier in the attributes list than attrs defined in the child. See: @@ -610,7 +612,7 @@ def collect_fields_and_classvar( # in the parent. We can implement this via a dict without disrupting the attr order # because dicts preserve insertion order in Python 3.7+. found_fields: dict[str, PydanticModelField] = {} - found_classvars: dict[str, PydanticModelClassVar] = {} + found_class_vars: dict[str, PydanticModelClassVar] = {} for info in reversed(cls.info.mro[1:-1]): # 0 is the current class, -2 is BaseModel, -1 is object # if BASEMODEL_METADATA_TAG_KEY in info.metadata and BASEMODEL_METADATA_KEY not in info.metadata: # # We haven't processed the base class yet. Need another pass. @@ -637,15 +639,15 @@ def collect_fields_and_classvar( 'BaseModel field may only be overridden by another field', sym_node.node, ) - # Collect classvars - for name, data in info.metadata[METADATA_KEY]['classvars'].items(): - found_classvars[name] = PydanticModelClassVar.deserialize(data) + # Collect ClassVars + for name, data in info.metadata[METADATA_KEY]['class_vars'].items(): + found_class_vars[name] = PydanticModelClassVar.deserialize(data) - # Second, collect fields and classvars belonging to the current class. + # Second, collect fields and ClassVars belonging to the current class. current_field_names: set[str] = set() - current_classvars_names: set[str] = set() + current_class_vars_names: set[str] = set() for stmt in self._get_assignment_statements_from_block(cls.defs): - maybe_field = self.collect_field_and_classvars_from_stmt(stmt, model_config, found_classvars) + maybe_field = self.collect_field_or_class_var_from_stmt(stmt, model_config, found_class_vars) if isinstance(maybe_field, PydanticModelField): lhs = stmt.lvalues[0] if is_root_model and lhs.name != 'root': @@ -655,10 +657,10 @@ def collect_fields_and_classvar( found_fields[lhs.name] = maybe_field elif isinstance(maybe_field, PydanticModelClassVar): lhs = stmt.lvalues[0] - current_classvars_names.add(lhs.name) - found_classvars[lhs.name] = maybe_field + current_class_vars_names.add(lhs.name) + found_class_vars[lhs.name] = maybe_field - return list(found_fields.values()), list(found_classvars.values()) + return list(found_fields.values()), list(found_class_vars.values()) def _get_assignment_statements_from_if_statement(self, stmt: IfStmt) -> Iterator[AssignmentStmt]: for body in stmt.body: @@ -674,14 +676,15 @@ def _get_assignment_statements_from_block(self, block: Block) -> Iterator[Assign elif isinstance(stmt, IfStmt): yield from self._get_assignment_statements_from_if_statement(stmt) - def collect_field_and_classvars_from_stmt( # noqa C901 - self, stmt: AssignmentStmt, model_config: ModelConfigData, classvars: dict[str, PydanticModelClassVar] + def collect_field_or_class_var_from_stmt( # noqa C901 + self, stmt: AssignmentStmt, model_config: ModelConfigData, class_vars: dict[str, PydanticModelClassVar] ) -> PydanticModelField | PydanticModelClassVar | None: """Get pydantic model field from statement. Args: stmt: The statement. model_config: Configuration settings for the model. + class_vars: ClassVars already known to be defined on the model. Returns: A pydantic model field if it could find the field in statement. Otherwise, `None`. @@ -704,7 +707,7 @@ def collect_field_and_classvars_from_stmt( # noqa C901 # Eventually, we may want to attempt to respect model_config['ignored_types'] return None - if lhs.name in classvars: + if lhs.name in class_vars: # Class vars are not fields and are not required to be annotated return None