From a7c0ce791d8d9164275600932c56980db61ab709 Mon Sep 17 00:00:00 2001 From: Marc Mueller <30130371+cdce8p@users.noreply.github.com> Date: Fri, 24 Feb 2023 17:36:21 +0100 Subject: [PATCH] Fix mypy plugin for 1.1.0 (#5077) * Fix mypy plugin for 1.1.0 * Code review * Add version key to plugin data --- changes/5077-cdce8p.md | 1 + pydantic/mypy.py | 33 +++++++++++++++++++++++++++++++-- 2 files changed, 32 insertions(+), 2 deletions(-) create mode 100644 changes/5077-cdce8p.md diff --git a/changes/5077-cdce8p.md b/changes/5077-cdce8p.md new file mode 100644 index 0000000000..6bae549c5a --- /dev/null +++ b/changes/5077-cdce8p.md @@ -0,0 +1 @@ +Fix mypy plugin for v1.1.0 diff --git a/pydantic/mypy.py b/pydantic/mypy.py index 02a0510a0f..83b13fb5e0 100644 --- a/pydantic/mypy.py +++ b/pydantic/mypy.py @@ -75,7 +75,7 @@ CONFIGFILE_KEY = 'pydantic-mypy' METADATA_KEY = 'pydantic-mypy-metadata' BASEMODEL_FULLNAME = 'pydantic.main.BaseModel' -BASESETTINGS_FULLNAME = 'pydantic.env_settings.BaseSettings' +MODEL_METACLASS_FULLNAME = 'pydantic.main.ModelMetaclass' FIELD_FULLNAME = 'pydantic.fields.Field' DATACLASS_FULLNAME = 'pydantic.dataclasses.dataclass' @@ -87,6 +87,9 @@ def parse_mypy_version(version: str) -> Tuple[int, ...]: MYPY_VERSION_TUPLE = parse_mypy_version(mypy_version) BUILTINS_NAME = 'builtins' if MYPY_VERSION_TUPLE >= (0, 930) else '__builtins__' +# Increment version if plugin changes and mypy caches should be invalidated +PLUGIN_VERSION = 1 + def plugin(version: str) -> 'TypingType[Plugin]': """ @@ -102,6 +105,7 @@ class PydanticPlugin(Plugin): def __init__(self, options: Options) -> None: self.plugin_config = PydanticPluginConfig(options) self._plugin_data = self.plugin_config.to_data() + self._plugin_data['version'] = PLUGIN_VERSION super().__init__(options) def get_base_class_hook(self, fullname: str) -> 'Optional[Callable[[ClassDefContext], None]]': @@ -112,6 +116,11 @@ def get_base_class_hook(self, fullname: str) -> 'Optional[Callable[[ClassDefCont return self._pydantic_model_class_maker_callback return None + def get_metaclass_hook(self, fullname: str) -> Optional[Callable[[ClassDefContext], None]]: + if fullname == MODEL_METACLASS_FULLNAME: + return self._pydantic_model_metaclass_marker_callback + return None + def get_function_hook(self, fullname: str) -> 'Optional[Callable[[FunctionContext], Type]]': sym = self.lookup_fully_qualified(fullname) if sym and sym.fullname == FIELD_FULLNAME: @@ -139,6 +148,19 @@ def _pydantic_model_class_maker_callback(self, ctx: ClassDefContext) -> None: transformer = PydanticModelTransformer(ctx, self.plugin_config) transformer.transform() + def _pydantic_model_metaclass_marker_callback(self, ctx: ClassDefContext) -> None: + """Reset dataclass_transform_spec attribute of ModelMetaclass. + + Let the plugin handle it. This behavior can be disabled + if 'debug_dataclass_transform' is set to True', for testing purposes. + """ + if self.plugin_config.debug_dataclass_transform: + return + info_metaclass = ctx.cls.info.declared_metaclass + assert info_metaclass, "callback not passed from 'get_metaclass_hook'" + if getattr(info_metaclass.type, 'dataclass_transform_spec', None): + info_metaclass.type.dataclass_transform_spec = None # type: ignore[attr-defined] + def _pydantic_field_callback(self, ctx: FunctionContext) -> 'Type': """ Extract the type of the `default` argument from the Field function, and use it as the return type. @@ -194,11 +216,18 @@ def _pydantic_field_callback(self, ctx: FunctionContext) -> 'Type': class PydanticPluginConfig: - __slots__ = ('init_forbid_extra', 'init_typed', 'warn_required_dynamic_aliases', 'warn_untyped_fields') + __slots__ = ( + 'init_forbid_extra', + 'init_typed', + 'warn_required_dynamic_aliases', + 'warn_untyped_fields', + 'debug_dataclass_transform', + ) init_forbid_extra: bool init_typed: bool warn_required_dynamic_aliases: bool warn_untyped_fields: bool + debug_dataclass_transform: bool # undocumented def __init__(self, options: Options) -> None: if options.config_file is None: # pragma: no cover