Skip to content

Commit

Permalink
Fix bug with mypy plugin's handling of covariant typevar fields (#9606)
Browse files Browse the repository at this point in the history
  • Loading branch information
dmontagu committed Jun 7, 2024
1 parent 4dfde6f commit 8333bd5
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 5 deletions.
38 changes: 33 additions & 5 deletions pydantic/mypy.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
ARG_OPT,
ARG_POS,
ARG_STAR2,
INVARIANT,
MDEF,
Argument,
AssignmentStmt,
Expand Down Expand Up @@ -343,9 +344,10 @@ def to_argument(
force_optional: bool,
use_alias: bool,
api: SemanticAnalyzerPluginInterface,
force_typevars_invariant: bool,
) -> Argument:
"""Based on mypy.plugins.dataclasses.DataclassAttribute.to_argument."""
variable = self.to_var(current_info, api, use_alias)
variable = self.to_var(current_info, api, use_alias, force_typevars_invariant)
type_annotation = self.expand_type(current_info, api) if typed else AnyType(TypeOfAny.explicit)
return Argument(
variable=variable,
Expand All @@ -354,26 +356,49 @@ def to_argument(
kind=ARG_NAMED_OPT if force_optional or self.has_default else ARG_NAMED,
)

def expand_type(self, current_info: TypeInfo, api: SemanticAnalyzerPluginInterface) -> Type | None:
def expand_type(
self, current_info: TypeInfo, api: SemanticAnalyzerPluginInterface, force_typevars_invariant: bool = False
) -> Type | None:
"""Based on mypy.plugins.dataclasses.DataclassAttribute.expand_type."""
# The getattr in the next line is used to prevent errors in legacy versions of mypy without this attribute
if force_typevars_invariant:
# In some cases, mypy will emit an error "Cannot use a covariant type variable as a parameter"
# To prevent that, we add an option to replace typevars with invariant ones while building certain
# method signatures (in particular, `__init__`). There may be a better way to do this, if this causes
# us problems in the future, we should look into why the dataclasses plugin doesn't have this issue.
if isinstance(self.type, TypeVarType):
modified_type = self.type.copy_modified()
modified_type.variance = INVARIANT
self.type = modified_type

if self.type is not None and getattr(self.info, 'self_type', None) is not None:
# In general, it is not safe to call `expand_type()` during semantic analyzis,
# however this plugin is called very late, so all types should be fully ready.
# Also, it is tricky to avoid eager expansion of Self types here (e.g. because
# we serialize attributes).
with state.strict_optional_set(api.options.strict_optional):
return expand_type(self.type, {self.info.self_type.id: fill_typevars(current_info)})
filled_with_typevars = fill_typevars(current_info)
if force_typevars_invariant:
for arg in filled_with_typevars.args:
if isinstance(arg, TypeVarType):
arg.variance = INVARIANT
return expand_type(self.type, {self.info.self_type.id: filled_with_typevars})
return self.type

def to_var(self, current_info: TypeInfo, api: SemanticAnalyzerPluginInterface, use_alias: bool) -> Var:
def to_var(
self,
current_info: TypeInfo,
api: SemanticAnalyzerPluginInterface,
use_alias: bool,
force_typevars_invariant: bool = False,
) -> Var:
"""Based on mypy.plugins.dataclasses.DataclassAttribute.to_var."""
if use_alias and self.alias is not None:
name = self.alias
else:
name = self.name

return Var(name, self.expand_type(current_info, api))
return Var(name, self.expand_type(current_info, api, force_typevars_invariant))

def serialize(self) -> JsonDict:
"""Based on mypy.plugins.dataclasses.DataclassAttribute.serialize."""
Expand Down Expand Up @@ -858,6 +883,7 @@ def add_initializer(
requires_dynamic_aliases=requires_dynamic_aliases,
use_alias=use_alias,
is_settings=is_settings,
force_typevars_invariant=True,
)

if is_root_model and MYPY_VERSION_TUPLE <= (1, 0, 1):
Expand Down Expand Up @@ -1037,6 +1063,7 @@ def get_field_arguments(
use_alias: bool,
requires_dynamic_aliases: bool,
is_settings: bool,
force_typevars_invariant: bool = False,
) -> list[Argument]:
"""Helper function used during the construction of the `__init__` and `model_construct` method signatures.
Expand All @@ -1050,6 +1077,7 @@ def get_field_arguments(
force_optional=requires_dynamic_aliases or is_settings,
use_alias=use_alias,
api=self._api,
force_typevars_invariant=force_typevars_invariant,
)
for field in fields
if not (use_alias and field.has_dynamic_alias)
Expand Down
12 changes: 12 additions & 0 deletions tests/mypy/modules/covariant_typevar.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from typing import Generic, TypeVar

from pydantic import BaseModel

T = TypeVar("T", covariant=True)


class Foo(BaseModel, Generic[T]):
value: T


class Bar(Foo[T]): ...
1 change: 1 addition & 0 deletions tests/mypy/test_mypy.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ def build(self) -> List[Union[Tuple[str, str], Any]]:
+ [
('mypy-plugin.ini', 'custom_constructor.py'),
('mypy-plugin.ini', 'config_conditional_extra.py'),
('mypy-plugin.ini', 'covariant_typevar.py'),
('mypy-plugin.ini', 'plugin_optional_inheritance.py'),
('mypy-plugin.ini', 'generics.py'),
('mypy-plugin.ini', 'root_models.py'),
Expand Down

0 comments on commit 8333bd5

Please sign in to comment.