Skip to content

Commit

Permalink
[cherry-pick] Fix mypy plugin for 1.1.0 (#5077) (#5111)
Browse files Browse the repository at this point in the history
* Fix mypy plugin for 1.1.0 (#5077)

* Fix mypy plugin for 1.1.0
* Code review
* Add version key to plugin data

(cherry picked from commit 6267ae3)

* Change file name

* Add the changes from #5120

* Update changes file

* Remove additional unneeded dataclass import (from #5120)

---------

Co-authored-by: David Montague <35119617+dmontagu@users.noreply.github.com>
  • Loading branch information
cdce8p and dmontagu committed Mar 8, 2023
1 parent 7f3b754 commit 9d0edbe
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 15 deletions.
1 change: 1 addition & 0 deletions changes/5111-cdce8p.md
@@ -0,0 +1 @@
Fix mypy plugin for v1.1.1, and fix `dataclass_transform` decorator for pydantic dataclasses
18 changes: 6 additions & 12 deletions pydantic/dataclasses.py
Expand Up @@ -32,6 +32,7 @@ class M:
validation without altering default `M` behaviour.
"""
import copy
import dataclasses
import sys
from contextlib import contextmanager
from functools import wraps
Expand Down Expand Up @@ -93,7 +94,7 @@ def __validate__(cls: Type['DataclassT'], v: Any) -> 'DataclassT':

if sys.version_info >= (3, 10):

@dataclass_transform(kw_only_default=True, field_specifiers=(Field, FieldInfo))
@dataclass_transform(field_specifiers=(dataclasses.field, Field))
@overload
def dataclass(
*,
Expand All @@ -110,7 +111,7 @@ def dataclass(
) -> Callable[[Type[_T]], 'DataclassClassOrWrapper']:
...

@dataclass_transform(kw_only_default=True, field_specifiers=(Field, FieldInfo))
@dataclass_transform(field_specifiers=(dataclasses.field, Field))
@overload
def dataclass(
_cls: Type[_T],
Expand All @@ -130,7 +131,7 @@ def dataclass(

else:

@dataclass_transform(kw_only_default=True, field_specifiers=(Field, FieldInfo))
@dataclass_transform(field_specifiers=(dataclasses.field, Field))
@overload
def dataclass(
*,
Expand All @@ -146,7 +147,7 @@ def dataclass(
) -> Callable[[Type[_T]], 'DataclassClassOrWrapper']:
...

@dataclass_transform(kw_only_default=True, field_specifiers=(Field, FieldInfo))
@dataclass_transform(field_specifiers=(dataclasses.field, Field))
@overload
def dataclass(
_cls: Type[_T],
Expand All @@ -164,7 +165,7 @@ def dataclass(
...


@dataclass_transform(kw_only_default=True, field_specifiers=(Field, FieldInfo))
@dataclass_transform(field_specifiers=(dataclasses.field, Field))
def dataclass(
_cls: Optional[Type[_T]] = None,
*,
Expand All @@ -188,8 +189,6 @@ def dataclass(
the_config = get_config(config)

def wrap(cls: Type[Any]) -> 'DataclassClassOrWrapper':
import dataclasses

should_use_proxy = (
use_proxy
if use_proxy is not None
Expand Down Expand Up @@ -328,7 +327,6 @@ def new_init(self: 'Dataclass', *args: Any, **kwargs: Any) -> None:
if hasattr(self, '__post_init_post_parse__'):
# We need to find again the initvars. To do that we use `__dataclass_fields__` instead of
# public method `dataclasses.fields`
import dataclasses

# get all initvars and their default values
initvars_and_values: Dict[str, Any] = {}
Expand Down Expand Up @@ -377,8 +375,6 @@ def create_pydantic_model_from_dataclass(
config: Type[Any] = BaseConfig,
dc_cls_doc: Optional[str] = None,
) -> Type['BaseModel']:
import dataclasses

field_definitions: Dict[str, Any] = {}
for field in dataclasses.fields(dc_cls):
default: Any = Undefined
Expand Down Expand Up @@ -466,8 +462,6 @@ class B(A):
In this case, when we first check `B`, we make an extra check and look at the annotations ('y'),
which won't be a superset of all the dataclass fields (only the stdlib fields i.e. 'x')
"""
import dataclasses

return (
dataclasses.is_dataclass(_cls)
and not hasattr(_cls, '__pydantic_model__')
Expand Down
3 changes: 1 addition & 2 deletions pydantic/main.py
Expand Up @@ -33,7 +33,6 @@
from .fields import (
MAPPING_LIKE_SHAPES,
Field,
FieldInfo,
ModelField,
ModelPrivateAttr,
PrivateAttr,
Expand Down Expand Up @@ -118,7 +117,7 @@ def hash_function(self_: Any) -> int:
_is_base_model_class_defined = False


@dataclass_transform(kw_only_default=True, field_specifiers=(Field, FieldInfo))
@dataclass_transform(kw_only_default=True, field_specifiers=(Field,))
class ModelMetaclass(ABCMeta):
@no_type_check # noqa C901
def __new__(mcs, name, bases, namespace, **kwargs): # noqa C901
Expand Down
32 changes: 31 additions & 1 deletion pydantic/mypy.py
Expand Up @@ -76,6 +76,7 @@
METADATA_KEY = 'pydantic-mypy-metadata'
BASEMODEL_FULLNAME = 'pydantic.main.BaseModel'
BASESETTINGS_FULLNAME = 'pydantic.env_settings.BaseSettings'
MODEL_METACLASS_FULLNAME = 'pydantic.main.ModelMetaclass'
FIELD_FULLNAME = 'pydantic.fields.Field'
DATACLASS_FULLNAME = 'pydantic.dataclasses.dataclass'

Expand All @@ -87,6 +88,9 @@ def parse_mypy_version(version: str) -> Tuple[int, ...]:
MYPY_VERSION_TUPLE = parse_mypy_version(mypy_version)
BUILTINS_NAME = 'builtins' if MYPY_VERSION_TUPLE >= (0, 930) else '__builtins__'

# Increment version if plugin changes and mypy caches should be invalidated
PLUGIN_VERSION = 1


def plugin(version: str) -> 'TypingType[Plugin]':
"""
Expand All @@ -102,6 +106,7 @@ class PydanticPlugin(Plugin):
def __init__(self, options: Options) -> None:
self.plugin_config = PydanticPluginConfig(options)
self._plugin_data = self.plugin_config.to_data()
self._plugin_data['version'] = PLUGIN_VERSION
super().__init__(options)

def get_base_class_hook(self, fullname: str) -> 'Optional[Callable[[ClassDefContext], None]]':
Expand All @@ -112,6 +117,11 @@ def get_base_class_hook(self, fullname: str) -> 'Optional[Callable[[ClassDefCont
return self._pydantic_model_class_maker_callback
return None

def get_metaclass_hook(self, fullname: str) -> Optional[Callable[[ClassDefContext], None]]:
if fullname == MODEL_METACLASS_FULLNAME:
return self._pydantic_model_metaclass_marker_callback
return None

def get_function_hook(self, fullname: str) -> 'Optional[Callable[[FunctionContext], Type]]':
sym = self.lookup_fully_qualified(fullname)
if sym and sym.fullname == FIELD_FULLNAME:
Expand Down Expand Up @@ -139,6 +149,19 @@ def _pydantic_model_class_maker_callback(self, ctx: ClassDefContext) -> None:
transformer = PydanticModelTransformer(ctx, self.plugin_config)
transformer.transform()

def _pydantic_model_metaclass_marker_callback(self, ctx: ClassDefContext) -> None:
"""Reset dataclass_transform_spec attribute of ModelMetaclass.
Let the plugin handle it. This behavior can be disabled
if 'debug_dataclass_transform' is set to True', for testing purposes.
"""
if self.plugin_config.debug_dataclass_transform:
return
info_metaclass = ctx.cls.info.declared_metaclass
assert info_metaclass, "callback not passed from 'get_metaclass_hook'"
if getattr(info_metaclass.type, 'dataclass_transform_spec', None):
info_metaclass.type.dataclass_transform_spec = None # type: ignore[attr-defined]

def _pydantic_field_callback(self, ctx: FunctionContext) -> 'Type':
"""
Extract the type of the `default` argument from the Field function, and use it as the return type.
Expand Down Expand Up @@ -194,11 +217,18 @@ def _pydantic_field_callback(self, ctx: FunctionContext) -> 'Type':


class PydanticPluginConfig:
__slots__ = ('init_forbid_extra', 'init_typed', 'warn_required_dynamic_aliases', 'warn_untyped_fields')
__slots__ = (
'init_forbid_extra',
'init_typed',
'warn_required_dynamic_aliases',
'warn_untyped_fields',
'debug_dataclass_transform',
)
init_forbid_extra: bool
init_typed: bool
warn_required_dynamic_aliases: bool
warn_untyped_fields: bool
debug_dataclass_transform: bool # undocumented

def __init__(self, options: Options) -> None:
if options.config_file is None: # pragma: no cover
Expand Down

0 comments on commit 9d0edbe

Please sign in to comment.