Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add add_overloaded_method_to_class helper to plugins/common.py #16038

Merged
merged 12 commits into from Sep 10, 2023
136 changes: 116 additions & 20 deletions mypy/plugins/common.py
@@ -1,5 +1,7 @@
from __future__ import annotations

from typing import NamedTuple

from mypy.argmap import map_actuals_to_formals
from mypy.fixup import TypeFixer
from mypy.nodes import (
Expand All @@ -16,9 +18,11 @@
JsonDict,
NameExpr,
Node,
OverloadedFuncDef,
PassStmt,
RefExpr,
SymbolTableNode,
TypeInfo,
Var,
)
from mypy.plugin import CheckerPluginInterface, ClassDefContext, SemanticAnalyzerPluginInterface
Expand Down Expand Up @@ -209,24 +213,99 @@ def add_method(
)


class MethodSpec(NamedTuple):
"""Represents a method signature to be added, except for `name`."""

args: list[Argument]
return_type: Type
self_type: Type | None = None
tvar_defs: list[TypeVarType] | None = None


def add_method_to_class(
api: SemanticAnalyzerPluginInterface | CheckerPluginInterface,
cls: ClassDef,
name: str,
# MethodSpec items kept for backward compatibility:
args: list[Argument],
return_type: Type,
self_type: Type | None = None,
tvar_def: TypeVarType | None = None,
tvar_def: list[TypeVarType] | TypeVarType | None = None,
is_classmethod: bool = False,
is_staticmethod: bool = False,
) -> None:
) -> FuncDef | Decorator:
"""Adds a new method to a class definition."""
_prepare_class_namespace(cls, name)

assert not (
is_classmethod is True and is_staticmethod is True
), "Can't add a new method that's both staticmethod and classmethod."
if tvar_def is not None and not isinstance(tvar_def, list):
tvar_def = [tvar_def]

func, sym = _add_method_by_spec(
api,
cls.info,
name,
MethodSpec(args=args, return_type=return_type, self_type=self_type, tvar_defs=tvar_def),
is_classmethod=is_classmethod,
is_staticmethod=is_staticmethod,
)
cls.info.names[name] = sym
cls.info.defn.defs.body.append(func)
return func


def add_overloaded_method_to_class(
api: SemanticAnalyzerPluginInterface | CheckerPluginInterface,
cls: ClassDef,
name: str,
items: list[MethodSpec],
is_classmethod: bool = False,
is_staticmethod: bool = False,
) -> OverloadedFuncDef:
"""Adds a new overloaded method to a class definition."""
assert len(items) >= 2, "Overloads must contain at least two cases"

# Save old definition, if it exists.
_prepare_class_namespace(cls, name)

# Create function bodies for each passed method spec.
funcs: list[Decorator | FuncDef] = []
for item in items:
func, _sym = _add_method_by_spec(
api,
cls.info,
name=name,
spec=item,
is_classmethod=is_classmethod,
is_staticmethod=is_staticmethod,
)
if isinstance(func, FuncDef):
var = Var(func.name, func.type)
var.set_line(func.line)
func.is_decorated = True
func.deco_line = func.line

deco = Decorator(func, [], var)
else:
deco = func
deco.is_overload = True
funcs.append(deco)

# Create the final OverloadedFuncDef node:
overload_def = OverloadedFuncDef(funcs)
overload_def.info = cls.info
overload_def.is_class = is_classmethod
overload_def.is_static = is_staticmethod
sym = SymbolTableNode(MDEF, overload_def)
sym.plugin_generated = True

cls.info.names[name] = sym
cls.info.defn.defs.body.append(overload_def)
return overload_def


def _prepare_class_namespace(cls: ClassDef, name: str) -> None:
info = cls.info
assert info

# First remove any previously generated methods with the same name
# to avoid clashes and problems in the semantic analyzer.
Expand All @@ -235,6 +314,29 @@ def add_method_to_class(
if sym.plugin_generated and isinstance(sym.node, FuncDef):
cls.defs.body.remove(sym.node)

# NOTE: we would like the plugin generated node to dominate, but we still
# need to keep any existing definitions so they get semantically analyzed.
if name in info.names:
# Get a nice unique name instead.
r_name = get_unique_redefinition_name(name, info.names)
info.names[r_name] = info.names[name]


def _add_method_by_spec(
api: SemanticAnalyzerPluginInterface | CheckerPluginInterface,
info: TypeInfo,
name: str,
spec: MethodSpec,
*,
is_classmethod: bool,
is_staticmethod: bool,
) -> tuple[FuncDef | Decorator, SymbolTableNode]:
args, return_type, self_type, tvar_defs = spec

assert not (
is_classmethod is True and is_staticmethod is True
), "Can't add a new method that's both staticmethod and classmethod."

if isinstance(api, SemanticAnalyzerPluginInterface):
function_type = api.named_type("builtins.function")
else:
Expand All @@ -258,8 +360,8 @@ def add_method_to_class(
arg_kinds.append(arg.kind)

signature = CallableType(arg_types, arg_kinds, arg_names, return_type, function_type)
if tvar_def:
signature.variables = [tvar_def]
if tvar_defs:
signature.variables = tvar_defs

func = FuncDef(name, args, Block([PassStmt()]))
func.info = info
Expand All @@ -269,13 +371,6 @@ def add_method_to_class(
func._fullname = info.fullname + "." + name
func.line = info.line

# NOTE: we would like the plugin generated node to dominate, but we still
# need to keep any existing definitions so they get semantically analyzed.
if name in info.names:
# Get a nice unique name instead.
r_name = get_unique_redefinition_name(name, info.names)
info.names[r_name] = info.names[name]

# Add decorator for is_staticmethod. It's unnecessary for is_classmethod.
if is_staticmethod:
func.is_decorated = True
Expand All @@ -286,12 +381,12 @@ def add_method_to_class(
dec = Decorator(func, [], v)
dec.line = info.line
sym = SymbolTableNode(MDEF, dec)
else:
sym = SymbolTableNode(MDEF, func)
sym.plugin_generated = True
info.names[name] = sym
sym.plugin_generated = True
return dec, sym

info.defn.defs.body.append(func)
sym = SymbolTableNode(MDEF, func)
sym.plugin_generated = True
return func, sym


def add_attribute_to_class(
Expand All @@ -304,7 +399,7 @@ def add_attribute_to_class(
override_allow_incompatible: bool = False,
fullname: str | None = None,
is_classvar: bool = False,
) -> None:
) -> Var:
"""
Adds a new attribute to a class definition.
This currently only generates the symbol table entry and no corresponding AssignmentStatement
Expand Down Expand Up @@ -335,6 +430,7 @@ def add_attribute_to_class(
info.names[name] = SymbolTableNode(
MDEF, node, plugin_generated=True, no_serialize=no_serialize
)
return node


def deserialize_and_fixup_type(data: str | JsonDict, api: SemanticAnalyzerPluginInterface) -> Type:
Expand Down
24 changes: 23 additions & 1 deletion test-data/unit/check-custom-plugin.test
Expand Up @@ -1011,13 +1011,35 @@ class BaseAddMethod: pass
class MyClass(BaseAddMethod):
pass

my_class = MyClass()
reveal_type(MyClass.foo_classmethod) # N: Revealed type is "def ()"
reveal_type(MyClass.foo_staticmethod) # N: Revealed type is "def (builtins.int) -> builtins.str"

my_class = MyClass()
reveal_type(my_class.foo_classmethod) # N: Revealed type is "def ()"
reveal_type(my_class.foo_staticmethod) # N: Revealed type is "def (builtins.int) -> builtins.str"
[file mypy.ini]
\[mypy]
plugins=<ROOT>/test-data/unit/plugins/add_classmethod.py

[case testAddOverloadedMethodPlugin]
# flags: --config-file tmp/mypy.ini
class AddOverloadedMethod: pass

class MyClass(AddOverloadedMethod):
pass

reveal_type(MyClass.method) # N: Revealed type is "Overload(def (self: __main__.MyClass, arg: builtins.int) -> builtins.str, def (self: __main__.MyClass, arg: builtins.str) -> builtins.int)"
reveal_type(MyClass.clsmethod) # N: Revealed type is "Overload(def (arg: builtins.int) -> builtins.str, def (arg: builtins.str) -> builtins.int)"
reveal_type(MyClass.stmethod) # N: Revealed type is "Overload(def (arg: builtins.int) -> builtins.str, def (arg: builtins.str) -> builtins.int)"

my_class = MyClass()
reveal_type(my_class.method) # N: Revealed type is "Overload(def (arg: builtins.int) -> builtins.str, def (arg: builtins.str) -> builtins.int)"
reveal_type(my_class.clsmethod) # N: Revealed type is "Overload(def (arg: builtins.int) -> builtins.str, def (arg: builtins.str) -> builtins.int)"
reveal_type(my_class.stmethod) # N: Revealed type is "Overload(def (arg: builtins.int) -> builtins.str, def (arg: builtins.str) -> builtins.int)"
[file mypy.ini]
\[mypy]
plugins=<ROOT>/test-data/unit/plugins/add_overloaded_method.py

[case testCustomErrorCodePlugin]
# flags: --config-file tmp/mypy.ini --show-error-codes
def main() -> int:
Expand Down
38 changes: 38 additions & 0 deletions test-data/unit/check-incremental.test
Expand Up @@ -5935,6 +5935,44 @@ tmp/b.py:4: note: Revealed type is "def ()"
tmp/b.py:5: note: Revealed type is "def (builtins.int) -> builtins.str"
tmp/b.py:6: note: Revealed type is "def ()"
tmp/b.py:7: note: Revealed type is "def (builtins.int) -> builtins.str"

[case testIncrementalAddOverloadedMethodPlugin]
# flags: --config-file tmp/mypy.ini
import b

[file mypy.ini]
\[mypy]
plugins=<ROOT>/test-data/unit/plugins/add_overloaded_method.py

[file a.py]
class AddOverloadedMethod: pass

class MyClass(AddOverloadedMethod):
pass

[file b.py]
import a

[file b.py.2]
import a

reveal_type(a.MyClass.method)
reveal_type(a.MyClass.clsmethod)
reveal_type(a.MyClass.stmethod)

my_class = a.MyClass()
reveal_type(my_class.method)
reveal_type(my_class.clsmethod)
reveal_type(my_class.stmethod)
[rechecked b]
[out2]
tmp/b.py:3: note: Revealed type is "Overload(def (self: a.MyClass, arg: builtins.int) -> builtins.str, def (self: a.MyClass, arg: builtins.str) -> builtins.int)"
tmp/b.py:4: note: Revealed type is "Overload(def (arg: builtins.int) -> builtins.str, def (arg: builtins.str) -> builtins.int)"
tmp/b.py:5: note: Revealed type is "Overload(def (arg: builtins.int) -> builtins.str, def (arg: builtins.str) -> builtins.int)"
tmp/b.py:8: note: Revealed type is "Overload(def (arg: builtins.int) -> builtins.str, def (arg: builtins.str) -> builtins.int)"
tmp/b.py:9: note: Revealed type is "Overload(def (arg: builtins.int) -> builtins.str, def (arg: builtins.str) -> builtins.int)"
tmp/b.py:10: note: Revealed type is "Overload(def (arg: builtins.int) -> builtins.str, def (arg: builtins.str) -> builtins.int)"

[case testGenericNamedTupleSerialization]
import b
[file a.py]
Expand Down
6 changes: 4 additions & 2 deletions test-data/unit/deps.test
Expand Up @@ -1387,12 +1387,13 @@ class B(A):
<m.A.(abstract)> -> <m.B.__init__>, m
<m.A.__dataclass_fields__> -> <m.B.__dataclass_fields__>
<m.A.__init__> -> <m.B.__init__>, m.B.__init__
<m.A.__mypy-replace> -> <m.B.__mypy-replace>, m.B.__mypy-replace
<m.A.__mypy-replace> -> <m.B.__mypy-replace>, m, m.B.__mypy-replace
<m.A.__new__> -> <m.B.__new__>
<m.A.x> -> <m.B.x>
<m.A.y> -> <m.B.y>
<m.A> -> m, m.A, m.B
<m.A[wildcard]> -> m
<m.B.__mypy-replace> -> m
<m.B.y> -> m
<m.B> -> m.B
<m.Z> -> m
Expand All @@ -1419,12 +1420,13 @@ class B(A):
<m.A.__dataclass_fields__> -> <m.B.__dataclass_fields__>
<m.A.__init__> -> <m.B.__init__>, m.B.__init__
<m.A.__match_args__> -> <m.B.__match_args__>
<m.A.__mypy-replace> -> <m.B.__mypy-replace>, m.B.__mypy-replace
<m.A.__mypy-replace> -> <m.B.__mypy-replace>, m, m.B.__mypy-replace
<m.A.__new__> -> <m.B.__new__>
<m.A.x> -> <m.B.x>
<m.A.y> -> <m.B.y>
<m.A> -> m, m.A, m.B
<m.A[wildcard]> -> m
<m.B.__mypy-replace> -> m
<m.B.y> -> m
<m.B> -> m.B
<m.Z> -> m
Expand Down
41 changes: 41 additions & 0 deletions test-data/unit/plugins/add_overloaded_method.py
@@ -0,0 +1,41 @@
from __future__ import annotations

from typing import Callable

from mypy.nodes import ARG_POS, Argument, Var
from mypy.plugin import ClassDefContext, Plugin
from mypy.plugins.common import MethodSpec, add_overloaded_method_to_class


class OverloadedMethodPlugin(Plugin):
def get_base_class_hook(self, fullname: str) -> Callable[[ClassDefContext], None] | None:
if "AddOverloadedMethod" in fullname:
return add_overloaded_method_hook
return None


def add_overloaded_method_hook(ctx: ClassDefContext) -> None:
add_overloaded_method_to_class(ctx.api, ctx.cls, "method", _generate_method_specs(ctx))
add_overloaded_method_to_class(
ctx.api, ctx.cls, "clsmethod", _generate_method_specs(ctx), is_classmethod=True
)
add_overloaded_method_to_class(
ctx.api, ctx.cls, "stmethod", _generate_method_specs(ctx), is_staticmethod=True
)


def _generate_method_specs(ctx: ClassDefContext) -> list[MethodSpec]:
return [
MethodSpec(
args=[Argument(Var("arg"), ctx.api.named_type("builtins.int"), None, ARG_POS)],
return_type=ctx.api.named_type("builtins.str"),
),
MethodSpec(
args=[Argument(Var("arg"), ctx.api.named_type("builtins.str"), None, ARG_POS)],
return_type=ctx.api.named_type("builtins.int"),
),
]


def plugin(version: str) -> type[OverloadedMethodPlugin]:
return OverloadedMethodPlugin