Skip to content

Commit

Permalink
Add add_overloaded_method_to_class helper to plugins/common.py (#…
Browse files Browse the repository at this point in the history
…16038)

There are several changes:

1. `add_overloaded_method_to_class` itself. It is very useful for plugin
authors, because right now it is quite easy to add a regular method, but
it is very hard to add a method with `@overload`s. I don't think that
user must face all the chalenges that I've covered in this method.
Moreover, it is quite easy even for experienced developers to forget
some flags / props / etc (I am pretty sure that I might forgot something
in the implementation)
2. `add_overloaded_method_to_class` and `add_method_to_class` now return
added nodes, it is also helpful if you want to do something with this
node in your plugin after it is created
3. I've refactored how `add_method_to_class` works and reused its parts
in the new method as well
4. `tvar_def` in `add_method_to_class` can now accept a list of type
vars, not just one

Notice that `add_method_to_class` is unchanged from the user's POV, it
should continue to work as before.

Tests are also updated to check that our overloads are correct.

Things to do later (in the next PRs / releases):
1. We can possibly add `is_final` param to methods as well
2. We can also support `@property` in a separate method at some point

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
sobolevn and pre-commit-ci[bot] committed Sep 10, 2023
1 parent ed18fea commit 9a35360
Show file tree
Hide file tree
Showing 5 changed files with 222 additions and 23 deletions.
136 changes: 116 additions & 20 deletions mypy/plugins/common.py
@@ -1,5 +1,7 @@
from __future__ import annotations

from typing import NamedTuple

from mypy.argmap import map_actuals_to_formals
from mypy.fixup import TypeFixer
from mypy.nodes import (
Expand All @@ -16,9 +18,11 @@
JsonDict,
NameExpr,
Node,
OverloadedFuncDef,
PassStmt,
RefExpr,
SymbolTableNode,
TypeInfo,
Var,
)
from mypy.plugin import CheckerPluginInterface, ClassDefContext, SemanticAnalyzerPluginInterface
Expand Down Expand Up @@ -209,24 +213,99 @@ def add_method(
)


class MethodSpec(NamedTuple):
"""Represents a method signature to be added, except for `name`."""

args: list[Argument]
return_type: Type
self_type: Type | None = None
tvar_defs: list[TypeVarType] | None = None


def add_method_to_class(
api: SemanticAnalyzerPluginInterface | CheckerPluginInterface,
cls: ClassDef,
name: str,
# MethodSpec items kept for backward compatibility:
args: list[Argument],
return_type: Type,
self_type: Type | None = None,
tvar_def: TypeVarType | None = None,
tvar_def: list[TypeVarType] | TypeVarType | None = None,
is_classmethod: bool = False,
is_staticmethod: bool = False,
) -> None:
) -> FuncDef | Decorator:
"""Adds a new method to a class definition."""
_prepare_class_namespace(cls, name)

assert not (
is_classmethod is True and is_staticmethod is True
), "Can't add a new method that's both staticmethod and classmethod."
if tvar_def is not None and not isinstance(tvar_def, list):
tvar_def = [tvar_def]

func, sym = _add_method_by_spec(
api,
cls.info,
name,
MethodSpec(args=args, return_type=return_type, self_type=self_type, tvar_defs=tvar_def),
is_classmethod=is_classmethod,
is_staticmethod=is_staticmethod,
)
cls.info.names[name] = sym
cls.info.defn.defs.body.append(func)
return func


def add_overloaded_method_to_class(
api: SemanticAnalyzerPluginInterface | CheckerPluginInterface,
cls: ClassDef,
name: str,
items: list[MethodSpec],
is_classmethod: bool = False,
is_staticmethod: bool = False,
) -> OverloadedFuncDef:
"""Adds a new overloaded method to a class definition."""
assert len(items) >= 2, "Overloads must contain at least two cases"

# Save old definition, if it exists.
_prepare_class_namespace(cls, name)

# Create function bodies for each passed method spec.
funcs: list[Decorator | FuncDef] = []
for item in items:
func, _sym = _add_method_by_spec(
api,
cls.info,
name=name,
spec=item,
is_classmethod=is_classmethod,
is_staticmethod=is_staticmethod,
)
if isinstance(func, FuncDef):
var = Var(func.name, func.type)
var.set_line(func.line)
func.is_decorated = True
func.deco_line = func.line

deco = Decorator(func, [], var)
else:
deco = func
deco.is_overload = True
funcs.append(deco)

# Create the final OverloadedFuncDef node:
overload_def = OverloadedFuncDef(funcs)
overload_def.info = cls.info
overload_def.is_class = is_classmethod
overload_def.is_static = is_staticmethod
sym = SymbolTableNode(MDEF, overload_def)
sym.plugin_generated = True

cls.info.names[name] = sym
cls.info.defn.defs.body.append(overload_def)
return overload_def


def _prepare_class_namespace(cls: ClassDef, name: str) -> None:
info = cls.info
assert info

# First remove any previously generated methods with the same name
# to avoid clashes and problems in the semantic analyzer.
Expand All @@ -235,6 +314,29 @@ def add_method_to_class(
if sym.plugin_generated and isinstance(sym.node, FuncDef):
cls.defs.body.remove(sym.node)

# NOTE: we would like the plugin generated node to dominate, but we still
# need to keep any existing definitions so they get semantically analyzed.
if name in info.names:
# Get a nice unique name instead.
r_name = get_unique_redefinition_name(name, info.names)
info.names[r_name] = info.names[name]


def _add_method_by_spec(
api: SemanticAnalyzerPluginInterface | CheckerPluginInterface,
info: TypeInfo,
name: str,
spec: MethodSpec,
*,
is_classmethod: bool,
is_staticmethod: bool,
) -> tuple[FuncDef | Decorator, SymbolTableNode]:
args, return_type, self_type, tvar_defs = spec

assert not (
is_classmethod is True and is_staticmethod is True
), "Can't add a new method that's both staticmethod and classmethod."

if isinstance(api, SemanticAnalyzerPluginInterface):
function_type = api.named_type("builtins.function")
else:
Expand All @@ -258,8 +360,8 @@ def add_method_to_class(
arg_kinds.append(arg.kind)

signature = CallableType(arg_types, arg_kinds, arg_names, return_type, function_type)
if tvar_def:
signature.variables = [tvar_def]
if tvar_defs:
signature.variables = tvar_defs

func = FuncDef(name, args, Block([PassStmt()]))
func.info = info
Expand All @@ -269,13 +371,6 @@ def add_method_to_class(
func._fullname = info.fullname + "." + name
func.line = info.line

# NOTE: we would like the plugin generated node to dominate, but we still
# need to keep any existing definitions so they get semantically analyzed.
if name in info.names:
# Get a nice unique name instead.
r_name = get_unique_redefinition_name(name, info.names)
info.names[r_name] = info.names[name]

# Add decorator for is_staticmethod. It's unnecessary for is_classmethod.
if is_staticmethod:
func.is_decorated = True
Expand All @@ -286,12 +381,12 @@ def add_method_to_class(
dec = Decorator(func, [], v)
dec.line = info.line
sym = SymbolTableNode(MDEF, dec)
else:
sym = SymbolTableNode(MDEF, func)
sym.plugin_generated = True
info.names[name] = sym
sym.plugin_generated = True
return dec, sym

info.defn.defs.body.append(func)
sym = SymbolTableNode(MDEF, func)
sym.plugin_generated = True
return func, sym


def add_attribute_to_class(
Expand All @@ -304,7 +399,7 @@ def add_attribute_to_class(
override_allow_incompatible: bool = False,
fullname: str | None = None,
is_classvar: bool = False,
) -> None:
) -> Var:
"""
Adds a new attribute to a class definition.
This currently only generates the symbol table entry and no corresponding AssignmentStatement
Expand Down Expand Up @@ -335,6 +430,7 @@ def add_attribute_to_class(
info.names[name] = SymbolTableNode(
MDEF, node, plugin_generated=True, no_serialize=no_serialize
)
return node


def deserialize_and_fixup_type(data: str | JsonDict, api: SemanticAnalyzerPluginInterface) -> Type:
Expand Down
24 changes: 23 additions & 1 deletion test-data/unit/check-custom-plugin.test
Expand Up @@ -1011,13 +1011,35 @@ class BaseAddMethod: pass
class MyClass(BaseAddMethod):
pass

my_class = MyClass()
reveal_type(MyClass.foo_classmethod) # N: Revealed type is "def ()"
reveal_type(MyClass.foo_staticmethod) # N: Revealed type is "def (builtins.int) -> builtins.str"

my_class = MyClass()
reveal_type(my_class.foo_classmethod) # N: Revealed type is "def ()"
reveal_type(my_class.foo_staticmethod) # N: Revealed type is "def (builtins.int) -> builtins.str"
[file mypy.ini]
\[mypy]
plugins=<ROOT>/test-data/unit/plugins/add_classmethod.py

[case testAddOverloadedMethodPlugin]
# flags: --config-file tmp/mypy.ini
class AddOverloadedMethod: pass

class MyClass(AddOverloadedMethod):
pass

reveal_type(MyClass.method) # N: Revealed type is "Overload(def (self: __main__.MyClass, arg: builtins.int) -> builtins.str, def (self: __main__.MyClass, arg: builtins.str) -> builtins.int)"
reveal_type(MyClass.clsmethod) # N: Revealed type is "Overload(def (arg: builtins.int) -> builtins.str, def (arg: builtins.str) -> builtins.int)"
reveal_type(MyClass.stmethod) # N: Revealed type is "Overload(def (arg: builtins.int) -> builtins.str, def (arg: builtins.str) -> builtins.int)"

my_class = MyClass()
reveal_type(my_class.method) # N: Revealed type is "Overload(def (arg: builtins.int) -> builtins.str, def (arg: builtins.str) -> builtins.int)"
reveal_type(my_class.clsmethod) # N: Revealed type is "Overload(def (arg: builtins.int) -> builtins.str, def (arg: builtins.str) -> builtins.int)"
reveal_type(my_class.stmethod) # N: Revealed type is "Overload(def (arg: builtins.int) -> builtins.str, def (arg: builtins.str) -> builtins.int)"
[file mypy.ini]
\[mypy]
plugins=<ROOT>/test-data/unit/plugins/add_overloaded_method.py

[case testCustomErrorCodePlugin]
# flags: --config-file tmp/mypy.ini --show-error-codes
def main() -> int:
Expand Down
38 changes: 38 additions & 0 deletions test-data/unit/check-incremental.test
Expand Up @@ -5935,6 +5935,44 @@ tmp/b.py:4: note: Revealed type is "def ()"
tmp/b.py:5: note: Revealed type is "def (builtins.int) -> builtins.str"
tmp/b.py:6: note: Revealed type is "def ()"
tmp/b.py:7: note: Revealed type is "def (builtins.int) -> builtins.str"

[case testIncrementalAddOverloadedMethodPlugin]
# flags: --config-file tmp/mypy.ini
import b

[file mypy.ini]
\[mypy]
plugins=<ROOT>/test-data/unit/plugins/add_overloaded_method.py

[file a.py]
class AddOverloadedMethod: pass

class MyClass(AddOverloadedMethod):
pass

[file b.py]
import a

[file b.py.2]
import a

reveal_type(a.MyClass.method)
reveal_type(a.MyClass.clsmethod)
reveal_type(a.MyClass.stmethod)

my_class = a.MyClass()
reveal_type(my_class.method)
reveal_type(my_class.clsmethod)
reveal_type(my_class.stmethod)
[rechecked b]
[out2]
tmp/b.py:3: note: Revealed type is "Overload(def (self: a.MyClass, arg: builtins.int) -> builtins.str, def (self: a.MyClass, arg: builtins.str) -> builtins.int)"
tmp/b.py:4: note: Revealed type is "Overload(def (arg: builtins.int) -> builtins.str, def (arg: builtins.str) -> builtins.int)"
tmp/b.py:5: note: Revealed type is "Overload(def (arg: builtins.int) -> builtins.str, def (arg: builtins.str) -> builtins.int)"
tmp/b.py:8: note: Revealed type is "Overload(def (arg: builtins.int) -> builtins.str, def (arg: builtins.str) -> builtins.int)"
tmp/b.py:9: note: Revealed type is "Overload(def (arg: builtins.int) -> builtins.str, def (arg: builtins.str) -> builtins.int)"
tmp/b.py:10: note: Revealed type is "Overload(def (arg: builtins.int) -> builtins.str, def (arg: builtins.str) -> builtins.int)"

[case testGenericNamedTupleSerialization]
import b
[file a.py]
Expand Down
6 changes: 4 additions & 2 deletions test-data/unit/deps.test
Expand Up @@ -1387,12 +1387,13 @@ class B(A):
<m.A.(abstract)> -> <m.B.__init__>, m
<m.A.__dataclass_fields__> -> <m.B.__dataclass_fields__>
<m.A.__init__> -> <m.B.__init__>, m.B.__init__
<m.A.__mypy-replace> -> <m.B.__mypy-replace>, m.B.__mypy-replace
<m.A.__mypy-replace> -> <m.B.__mypy-replace>, m, m.B.__mypy-replace
<m.A.__new__> -> <m.B.__new__>
<m.A.x> -> <m.B.x>
<m.A.y> -> <m.B.y>
<m.A> -> m, m.A, m.B
<m.A[wildcard]> -> m
<m.B.__mypy-replace> -> m
<m.B.y> -> m
<m.B> -> m.B
<m.Z> -> m
Expand All @@ -1419,12 +1420,13 @@ class B(A):
<m.A.__dataclass_fields__> -> <m.B.__dataclass_fields__>
<m.A.__init__> -> <m.B.__init__>, m.B.__init__
<m.A.__match_args__> -> <m.B.__match_args__>
<m.A.__mypy-replace> -> <m.B.__mypy-replace>, m.B.__mypy-replace
<m.A.__mypy-replace> -> <m.B.__mypy-replace>, m, m.B.__mypy-replace
<m.A.__new__> -> <m.B.__new__>
<m.A.x> -> <m.B.x>
<m.A.y> -> <m.B.y>
<m.A> -> m, m.A, m.B
<m.A[wildcard]> -> m
<m.B.__mypy-replace> -> m
<m.B.y> -> m
<m.B> -> m.B
<m.Z> -> m
Expand Down
41 changes: 41 additions & 0 deletions test-data/unit/plugins/add_overloaded_method.py
@@ -0,0 +1,41 @@
from __future__ import annotations

from typing import Callable

from mypy.nodes import ARG_POS, Argument, Var
from mypy.plugin import ClassDefContext, Plugin
from mypy.plugins.common import MethodSpec, add_overloaded_method_to_class


class OverloadedMethodPlugin(Plugin):
def get_base_class_hook(self, fullname: str) -> Callable[[ClassDefContext], None] | None:
if "AddOverloadedMethod" in fullname:
return add_overloaded_method_hook
return None


def add_overloaded_method_hook(ctx: ClassDefContext) -> None:
add_overloaded_method_to_class(ctx.api, ctx.cls, "method", _generate_method_specs(ctx))
add_overloaded_method_to_class(
ctx.api, ctx.cls, "clsmethod", _generate_method_specs(ctx), is_classmethod=True
)
add_overloaded_method_to_class(
ctx.api, ctx.cls, "stmethod", _generate_method_specs(ctx), is_staticmethod=True
)


def _generate_method_specs(ctx: ClassDefContext) -> list[MethodSpec]:
return [
MethodSpec(
args=[Argument(Var("arg"), ctx.api.named_type("builtins.int"), None, ARG_POS)],
return_type=ctx.api.named_type("builtins.str"),
),
MethodSpec(
args=[Argument(Var("arg"), ctx.api.named_type("builtins.str"), None, ARG_POS)],
return_type=ctx.api.named_type("builtins.int"),
),
]


def plugin(version: str) -> type[OverloadedMethodPlugin]:
return OverloadedMethodPlugin

0 comments on commit 9a35360

Please sign in to comment.