diff --git a/pydantic/mypy.py b/pydantic/mypy.py index c2eb919472..e91aa7e613 100644 --- a/pydantic/mypy.py +++ b/pydantic/mypy.py @@ -89,6 +89,7 @@ METADATA_KEY = 'pydantic-mypy-metadata' BASEMODEL_FULLNAME = 'pydantic.main.BaseModel' BASESETTINGS_FULLNAME = 'pydantic_settings.main.BaseSettings' +ROOT_MODEL_FULLNAME = 'pydantic.root_model.RootModel' MODEL_METACLASS_FULLNAME = 'pydantic._internal._model_construction.ModelMetaclass' FIELD_FULLNAME = 'pydantic.fields.Field' DATACLASS_FULLNAME = 'pydantic.dataclasses.dataclass' @@ -430,8 +431,9 @@ def transform(self) -> bool: * stores the fields, config, and if the class is settings in the mypy metadata for access by subclasses """ info = self._cls.info + is_root_model = any(ROOT_MODEL_FULLNAME in base.fullname for base in info.mro[:-1]) config = self.collect_config() - fields = self.collect_fields(config) + fields = self.collect_fields(config, is_root_model) if fields is None: # Some definitions are not ready. We need another pass. return False @@ -440,7 +442,7 @@ def transform(self) -> bool: return False is_settings = any(base.fullname == BASESETTINGS_FULLNAME for base in info.mro[:-1]) - self.add_initializer(fields, config, is_settings) + self.add_initializer(fields, config, is_settings, is_root_model) self.add_model_construct_method(fields, config, is_settings) self.set_frozen(fields, frozen=config.frozen is True) @@ -556,7 +558,7 @@ def collect_config(self) -> ModelConfigData: # noqa: C901 (ignore complexity) config.setdefault(name, value) return config - def collect_fields(self, model_config: ModelConfigData) -> list[PydanticModelField] | None: + def collect_fields(self, model_config: ModelConfigData, is_root_model: bool) -> list[PydanticModelField] | None: """Collects the fields for the model, accounting for parent classes.""" cls = self._cls @@ -603,8 +605,11 @@ def collect_fields(self, model_config: ModelConfigData) -> list[PydanticModelFie maybe_field = self.collect_field_from_stmt(stmt, model_config) if maybe_field is not None: lhs = stmt.lvalues[0] - current_field_names.add(lhs.name) - found_fields[lhs.name] = maybe_field + if is_root_model and lhs.name != 'root': + error_extra_fields_on_root_model(self._api, stmt) + else: + current_field_names.add(lhs.name) + found_fields[lhs.name] = maybe_field return list(found_fields.values()) @@ -780,7 +785,9 @@ def _infer_dataclass_attr_init_type(self, sym: SymbolTableNode, name: str, conte return default - def add_initializer(self, fields: list[PydanticModelField], config: ModelConfigData, is_settings: bool) -> None: + def add_initializer( + self, fields: list[PydanticModelField], config: ModelConfigData, is_settings: bool, is_root_model: bool + ) -> None: """Adds a fields-aware `__init__` method to the class. The added `__init__` will be annotated with types vs. all `Any` depending on the plugin settings. @@ -799,6 +806,10 @@ def add_initializer(self, fields: list[PydanticModelField], config: ModelConfigD use_alias=use_alias, is_settings=is_settings, ) + if is_root_model: + # convert root argument to positional argument + args[0].kind = ARG_POS if args[0].kind == ARG_NAMED else ARG_OPT + if is_settings: base_settings_node = self._api.lookup_fully_qualified(BASESETTINGS_FULLNAME).node if '__init__' in base_settings_node.names: @@ -1048,6 +1059,7 @@ def setdefault(self, key: str, value: Any) -> None: ERROR_UNEXPECTED = ErrorCode('pydantic-unexpected', 'Unexpected behavior', 'Pydantic') ERROR_UNTYPED = ErrorCode('pydantic-field', 'Untyped field disallowed', 'Pydantic') ERROR_FIELD_DEFAULTS = ErrorCode('pydantic-field', 'Invalid Field defaults', 'Pydantic') +ERROR_EXTRA_FIELD_ROOT_MODEL = ErrorCode('pydantic-field', 'Extra field on RootModel subclass', 'Pydantic') def error_from_attributes(model_name: str, api: CheckerPluginInterface, context: Context) -> None: @@ -1084,6 +1096,11 @@ def error_untyped_fields(api: SemanticAnalyzerPluginInterface, context: Context) api.fail('Untyped fields disallowed', context, code=ERROR_UNTYPED) +def error_extra_fields_on_root_model(api: CheckerPluginInterface, context: Context) -> None: + """Emits an error when there is more than just a root field defined for a subclass of RootModel.""" + api.fail('Only `root` is allowed as a field of a `RootModel`', context, code=ERROR_EXTRA_FIELD_ROOT_MODEL) + + def error_default_and_default_factory_specified(api: CheckerPluginInterface, context: Context) -> None: """Emits an error when `Field` has both `default` and `default_factory` together.""" api.fail('Field default and default_factory cannot be specified together', context, code=ERROR_FIELD_DEFAULTS) diff --git a/tests/mypy/modules/root_models.py b/tests/mypy/modules/root_models.py new file mode 100644 index 0000000000..844088d488 --- /dev/null +++ b/tests/mypy/modules/root_models.py @@ -0,0 +1,23 @@ +from typing import List + +from pydantic import RootModel + + +class Pets1(RootModel[List[str]]): + pass + + +Pets2 = RootModel[List[str]] + + +class Pets3(RootModel): + root: List[str] + + +pets1 = Pets1(['dog', 'cat']) +pets2 = Pets2(['dog', 'cat']) +pets3 = Pets3(['dog', 'cat']) + + +class Pets4(RootModel[List[str]]): + pets: List[str] diff --git a/tests/mypy/outputs/1.0.1/mypy-plugin_ini/root_models.py b/tests/mypy/outputs/1.0.1/mypy-plugin_ini/root_models.py new file mode 100644 index 0000000000..24d5f5f06d --- /dev/null +++ b/tests/mypy/outputs/1.0.1/mypy-plugin_ini/root_models.py @@ -0,0 +1,25 @@ +from typing import List + +from pydantic import RootModel + + +class Pets1(RootModel[List[str]]): + pass + + +Pets2 = RootModel[List[str]] + + +class Pets3(RootModel): +# MYPY: error: Missing type parameters for generic type "RootModel" [type-arg] + root: List[str] + + +pets1 = Pets1(['dog', 'cat']) +pets2 = Pets2(['dog', 'cat']) +pets3 = Pets3(['dog', 'cat']) + + +class Pets4(RootModel[List[str]]): + pets: List[str] +# MYPY: error: Only `root` is allowed as a field of a `RootModel` [pydantic-field] diff --git a/tests/mypy/test_mypy.py b/tests/mypy/test_mypy.py index 782caaed1b..435b6c846a 100644 --- a/tests/mypy/test_mypy.py +++ b/tests/mypy/test_mypy.py @@ -98,6 +98,7 @@ def build(self) -> List[Union[Tuple[str, str], Any]]: + [ ('mypy-plugin.ini', 'custom_constructor.py'), ('mypy-plugin.ini', 'generics.py'), + ('mypy-plugin.ini', 'root_models.py'), ('mypy-plugin-strict.ini', 'plugin_default_factory.py'), ('mypy-plugin-strict-no-any.ini', 'dataclass_no_any.py'), ('mypy-plugin-very-strict.ini', 'metaclass_args.py'),